RamezCh committed on
Commit
722f900
·
verified ·
1 Parent(s): 9084ef7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +100 -41
README.md CHANGED
@@ -103,62 +103,121 @@ The model depends on the base `sproto` package (which contains `MultiProtoModule
103
 
104
  ```bash
105
  pip install torch>=1.12.1 \
106
- transformers>=4.25.1 \
107
- torchmetrics>=0.10.1 \
108
- pytorch-lightning==1.9
 
 
109
  ```
110
 
111
  | Package | Required version | Reason |
112
  |---------|-----------------|--------|
113
  | `torch` | `>= 1.12.1` | Minimum version for `nn.PairwiseDistance` and `torch.einsum` patterns used in the prototype layer |
114
- | `transformers` | `>= 4.25.1` | Minimum version with `trust_remote_code` + `auto_map` support for custom model loading |
115
- | `torchmetrics` | `>= 0.10.1` | `MultilabelAveragePrecision` was added in 0.10; older versions raise `AttributeError` on load |
116
  | `pytorch-lightning` | `== 1.9` | `MultiProtoModule` is a `pl.LightningModule`; the exact API (e.g. `validation_epoch_end`) changed in 2.x |
 
 
117
  | `sproto` | bundled | The `sproto/` package is included in this HF repo and downloaded automatically with `trust_remote_code=True` — no separate install needed |
118
 
119
  ## Inference Example
120
 
121
  ```python
122
- from transformers import AutoTokenizer, AutoModel
123
  import torch
 
 
 
 
124
 
125
- tokenizer = AutoTokenizer.from_pretrained("DATEXIS/sproto")
126
- model = AutoModel.from_pretrained("DATEXIS/sproto", trust_remote_code=True)
127
- model.eval()
128
-
129
- text_input = [
130
- "CHIEF COMPLAINT: Right Carotid Artery Stenosis. "
131
- "PRESENT ILLNESS: Ms. ___ is a ___ year old woman with hyperlipidemia, "
132
- "cirrhosis with esophageal varices, alcoholism, COPD, left eye blindness, "
133
- "and right carotid stenosis status post right carotid endarterectomy."
134
- ]
135
-
136
- inputs = tokenizer(
137
- text_input,
138
- padding=True,
139
- truncation=True,
140
- max_length=512,
141
- return_tensors="pt"
142
- )
143
-
144
- tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]
145
-
146
- with torch.no_grad():
147
- output = model(
148
- input_ids=inputs["input_ids"],
149
- attention_mask=inputs["attention_mask"],
150
- token_type_ids=inputs.get("token_type_ids"),
151
- tokens=tokens
152
  )
153
 
154
- logits = output["logits"]
155
- max_indices = output["max_indices"]
156
- metadata = output["metadata"]
157
-
158
- print("Inference successful")
159
- print("Logits shape:", logits.shape)
160
- print("Max indices:", max_indices)
161
- print("Metadata:", metadata)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  ```
163
 
164
  > **Note:** `tokens` (the list of token strings per sample) is **required** when `use_attention=True`
 
103
 
104
  ```bash
105
  pip install "torch>=1.12.1" \
106
+ transformers==4.40.0 \
107
+ torchmetrics==0.10.3 \
108
+ pytorch-lightning==1.9 \
109
+ huggingface-hub \
110
+ matplotlib
111
  ```
112
 
113
  | Package | Required version | Reason |
114
  |---------|-----------------|--------|
115
  | `torch` | `>= 1.12.1` | Minimum version for `nn.PairwiseDistance` and `torch.einsum` patterns used in the prototype layer |
116
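+ | `transformers` | `== 4.40.0` | Pinned to the version the inference example targets; the example passes `use_safetensors=False` to work around a metadata parsing bug in this release |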
117
+ | `torchmetrics` | `== 0.10.3` | `MultilabelAveragePrecision` was added in 0.10; older versions raise `AttributeError` on load |
118
  | `pytorch-lightning` | `== 1.9` | `MultiProtoModule` is a `pl.LightningModule`; the exact API (e.g. `validation_epoch_end`) changed in 2.x |
119
+ | `huggingface-hub` | any | Required for fetching additional assets like thresholds and labels |
120
+ | `matplotlib` | any | Used for visualizations |
121
  | `sproto` | bundled | The `sproto/` package is included in this HF repo and downloaded automatically with `trust_remote_code=True` — no separate install needed |
122
 
123
  ## Inference Example
124
 
125
  ```python
 
126
  import torch
127
+ import sys
128
+ import json
129
+ from huggingface_hub import snapshot_download, hf_hub_download
130
+ from transformers import AutoTokenizer, AutoModel
131
 
132
+ def main():
133
+ # 1. Download the repo and inject it into sys.path to resolve the internal 'sproto' package
134
+ repo_id = "DATEXIS/sproto"
135
+ repo_path = snapshot_download(repo_id)
136
+ if repo_path not in sys.path:
137
+ sys.path.insert(0, repo_path)
138
+
139
+ # 2. Load Tokenizer and Model
140
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
141
+ # use_safetensors=False is required to bypass a metadata parsing bug in transformers 4.40.0
142
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, use_safetensors=False)
143
+ model.eval()
144
+
145
+ # 3. Prepare Input Text
146
+ text = """CHIEF COMPLAINT: depression, chest pain and vomiting
147
+
148
+ PRESENT ILLNESS: The patient is a 53-year-old woman with a history of hypertension, diabetes, and depression. She developed severe anxiety and depression. She was having chest pains along with significant vomiting and diarrhea.
149
+ """
150
+
151
+ inputs = tokenizer(
152
+ text,
153
+ return_tensors="pt",
154
+ padding="max_length",
155
+ truncation=True,
156
+ max_length=512
 
 
157
  )
158
 
159
+ # Sproto requires raw token strings for its clinical section masking logic
160
+ tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]
161
+
162
+ # 4. Forward Pass
163
+ with torch.no_grad():
164
+ outputs = model(
165
+ input_ids=inputs["input_ids"],
166
+ attention_mask=inputs["attention_mask"],
167
+ tokens=tokens
168
+ )
169
+
170
+ # Apply sigmoid to turn the raw logits (trained with a BCE loss) into per-label probabilities
171
+ probs = torch.sigmoid(outputs.logits)[0]
172
+
173
+ # 5. Fetch Labels and Thresholds dynamically from Hugging Face Hub
174
+ try:
175
+ labels_path = hf_hub_download(repo_id=repo_id, filename="labels.txt")
176
+ icd_mapping_path = hf_hub_download(repo_id=repo_id, filename="icd_10_mappings.json")
177
+ thresholds_path = hf_hub_download(repo_id=repo_id, filename="thresholds_per_label.json")
178
+
179
+ with open(labels_path, "r") as f:
180
+ labels = f.read().strip().split("\n")
181
+ with open(icd_mapping_path, "r") as f:
182
+ icd_mapping = json.load(f)
183
+ with open(thresholds_path, "r") as f:
184
+ threshold_mapping = json.load(f)
185
+ except Exception as e:
186
+ print(f"Warning: Could not load label mapping files from HF Hub: {e}")
187
+ labels, threshold_mapping = None, None
188
+
189
+ # 6. Evaluate and Print Results
190
+ print("\n--- Inference Results ---")
191
+ if labels and threshold_mapping:
192
+ threshold_tensor = torch.zeros(len(labels))
193
+ for idx, label in enumerate(labels):
194
+ val = threshold_mapping.get(label, 0.20)
195
+ threshold_tensor[idx] = val if val > 0.0 else 0.20 # Fall back to 0.20 when the stored threshold is not positive
196
+
197
+ predicted_indices = torch.where(probs > threshold_tensor)[0]
198
+ else:
199
+ predicted_indices = torch.where(probs > 0.20)[0]
200
+
201
+ if len(predicted_indices) == 0:
202
+ print("No diagnoses predicted above the threshold.")
203
+ else:
204
+ results = []
205
+ for idx in predicted_indices:
206
+ idx_val = idx.item()
207
+ prob = probs[idx_val].item()
208
+
209
+ if labels and idx_val < len(labels):
210
+ icd_code = labels[idx_val]
211
+ description = icd_mapping.get(icd_code, "Unknown Description")
212
+ results.append((icd_code, description, prob))
213
+
214
+ # Sort alphabetically by ICD-10 code
215
+ results.sort(key=lambda x: x[0])
216
+ for icd_code, description, prob in results:
217
+ print(f"- {icd_code} ({description}): {prob:.4f}")
218
+
219
+ if __name__ == "__main__":
220
+ main()
221
  ```
222
 
223
  > **Note:** `tokens` (the list of token strings per sample) is **required** when `use_attention=True`