Text Classification
Transformers
PyTorch
Safetensors
English
sproto
multi-label-classification
long-tail-learning
medical
clinical-nlp
interpretability
prototypical-networks
ehr
custom_code
Instructions to use DATEXIS/sproto with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use DATEXIS/sproto with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="DATEXIS/sproto", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("DATEXIS/sproto", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files
README.md
CHANGED
|
@@ -103,62 +103,121 @@ The model depends on the base `sproto` package (which contains `MultiProtoModule
|
|
| 103 |
|
| 104 |
```bash
|
| 105 |
pip install "torch>=1.12.1" \
|
| 106 |
-
transformers
|
| 107 |
-
torchmetrics
|
| 108 |
-
pytorch-lightning==1.9
|
|
|
|
|
|
|
| 109 |
```
|
| 110 |
|
| 111 |
| Package | Required version | Reason |
|
| 112 |
|---------|-----------------|--------|
|
| 113 |
| `torch` | `>= 1.12.1` | Minimum version for `nn.PairwiseDistance` and `torch.einsum` patterns used in the prototype layer |
|
| 114 |
-
| `transformers` | `
|
| 115 |
-
| `torchmetrics` | `
|
| 116 |
| `pytorch-lightning` | `== 1.9` | `MultiProtoModule` is a `pl.LightningModule`; the exact API (e.g. `validation_epoch_end`) changed in 2.x |
|
|
|
|
|
|
|
| 117 |
| `sproto` | bundled | The `sproto/` package is included in this HF repo and downloaded automatically with `trust_remote_code=True` — no separate install needed |
|
| 118 |
|
| 119 |
## Inference Example
|
| 120 |
|
| 121 |
```python
|
| 122 |
-
from transformers import AutoTokenizer, AutoModel
|
| 123 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
token_type_ids=inputs.get("token_type_ids"),
|
| 151 |
-
tokens=tokens
|
| 152 |
)
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
```
|
| 163 |
|
| 164 |
> **Note:** `tokens` (the list of token strings per sample) is **required** when `use_attention=True`
|
|
|
|
| 103 |
|
| 104 |
```bash
|
| 105 |
pip install "torch>=1.12.1" \
|
| 106 |
+
transformers==4.40.0 \
|
| 107 |
+
torchmetrics==0.10.3 \
|
| 108 |
+
pytorch-lightning==1.9 \
|
| 109 |
+
huggingface-hub \
|
| 110 |
+
matplotlib
|
| 111 |
```
|
| 112 |
|
| 113 |
| Package | Required version | Reason |
|
| 114 |
|---------|-----------------|--------|
|
| 115 |
| `torch` | `>= 1.12.1` | Minimum version for `nn.PairwiseDistance` and `torch.einsum` patterns used in the prototype layer |
|
| 116 |
+
| `transformers` | `== 4.40.0` | Version targeted by the inference example below, which passes `use_safetensors=False` to bypass a metadata parsing bug in this release |
|
| 117 |
+
| `torchmetrics` | `== 0.10.3` | `MultilabelAveragePrecision` was added in 0.10; older versions raise `AttributeError` on load |
|
| 118 |
| `pytorch-lightning` | `== 1.9` | `MultiProtoModule` is a `pl.LightningModule`; the exact API (e.g. `validation_epoch_end`) changed in 2.x |
|
| 119 |
+
| `huggingface-hub` | any | Required for fetching additional assets like thresholds and labels |
|
| 120 |
+
| `matplotlib` | any | Used for visualizations |
|
| 121 |
| `sproto` | bundled | The `sproto/` package is included in this HF repo and downloaded automatically with `trust_remote_code=True` — no separate install needed |
|
| 122 |
|
| 123 |
## Inference Example
|
| 124 |
|
| 125 |
```python
|
|
|
|
| 126 |
import torch
|
| 127 |
+
import sys
|
| 128 |
+
import json
|
| 129 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
| 130 |
+
from transformers import AutoTokenizer, AutoModel
|
| 131 |
|
| 132 |
+
def main():
|
| 133 |
+
# 1. Download the repo and inject it into sys.path to resolve the internal 'sproto' package
|
| 134 |
+
repo_id = "DATEXIS/sproto"
|
| 135 |
+
repo_path = snapshot_download(repo_id)
|
| 136 |
+
if repo_path not in sys.path:
|
| 137 |
+
sys.path.insert(0, repo_path)
|
| 138 |
+
|
| 139 |
+
# 2. Load Tokenizer and Model
|
| 140 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
| 141 |
+
# use_safetensors=False is required to bypass a metadata parsing bug in transformers 4.40.0
|
| 142 |
+
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, use_safetensors=False)
|
| 143 |
+
model.eval()
|
| 144 |
+
|
| 145 |
+
# 3. Prepare Input Text
|
| 146 |
+
text = """CHIEF COMPLAINT: depression, chest pain and vomiting
|
| 147 |
+
|
| 148 |
+
PRESENT ILLNESS: The patient is a 53-year-old woman with a history of hypertension, diabetes, and depression. She developed severe anxiety and depression. She was having chest pains along with significant vomiting and diarrhea.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
inputs = tokenizer(
|
| 152 |
+
text,
|
| 153 |
+
return_tensors="pt",
|
| 154 |
+
padding="max_length",
|
| 155 |
+
truncation=True,
|
| 156 |
+
max_length=512
|
|
|
|
|
|
|
| 157 |
)
|
| 158 |
|
| 159 |
+
# Sproto requires raw token strings for its clinical section masking logic
|
| 160 |
+
tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]
|
| 161 |
+
|
| 162 |
+
# 4. Forward Pass
|
| 163 |
+
with torch.no_grad():
|
| 164 |
+
outputs = model(
|
| 165 |
+
input_ids=inputs["input_ids"],
|
| 166 |
+
attention_mask=inputs["attention_mask"],
|
| 167 |
+
tokens=tokens
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Apply sigmoid to convert the logits (trained with a BCE loss) into per-label probabilities
|
| 171 |
+
probs = torch.sigmoid(outputs.logits)[0]
|
| 172 |
+
|
| 173 |
+
# 5. Fetch Labels and Thresholds dynamically from Hugging Face Hub
|
| 174 |
+
try:
|
| 175 |
+
labels_path = hf_hub_download(repo_id=repo_id, filename="labels.txt")
|
| 176 |
+
icd_mapping_path = hf_hub_download(repo_id=repo_id, filename="icd_10_mappings.json")
|
| 177 |
+
thresholds_path = hf_hub_download(repo_id=repo_id, filename="thresholds_per_label.json")
|
| 178 |
+
|
| 179 |
+
with open(labels_path, "r") as f:
|
| 180 |
+
labels = f.read().strip().split("\n")
|
| 181 |
+
with open(icd_mapping_path, "r") as f:
|
| 182 |
+
icd_mapping = json.load(f)
|
| 183 |
+
with open(thresholds_path, "r") as f:
|
| 184 |
+
threshold_mapping = json.load(f)
|
| 185 |
+
except Exception as e:
|
| 186 |
+
print(f"Warning: Could not load label mapping files from HF Hub: {e}")
|
| 187 |
+
labels, threshold_mapping = None, None
|
| 188 |
+
|
| 189 |
+
# 6. Evaluate and Print Results
|
| 190 |
+
print("\n--- Inference Results ---")
|
| 191 |
+
if labels and threshold_mapping:
|
| 192 |
+
threshold_tensor = torch.zeros(len(labels))
|
| 193 |
+
for idx, label in enumerate(labels):
|
| 194 |
+
val = threshold_mapping.get(label, 0.20)
|
| 195 |
+
threshold_tensor[idx] = val if val > 0.0 else 0.20 # Fall back to 0.20 when the stored threshold is not positive
|
| 196 |
+
|
| 197 |
+
predicted_indices = torch.where(probs > threshold_tensor)[0]
|
| 198 |
+
else:
|
| 199 |
+
predicted_indices = torch.where(probs > 0.20)[0]
|
| 200 |
+
|
| 201 |
+
if len(predicted_indices) == 0:
|
| 202 |
+
print("No diagnoses predicted above the threshold.")
|
| 203 |
+
else:
|
| 204 |
+
results = []
|
| 205 |
+
for idx in predicted_indices:
|
| 206 |
+
idx_val = idx.item()
|
| 207 |
+
prob = probs[idx_val].item()
|
| 208 |
+
|
| 209 |
+
if labels and idx_val < len(labels):
|
| 210 |
+
icd_code = labels[idx_val]
|
| 211 |
+
description = icd_mapping.get(icd_code, "Unknown Description")
|
| 212 |
+
results.append((icd_code, description, prob))
|
| 213 |
+
|
| 214 |
+
# Sort alphabetically by ICD-10 code
|
| 215 |
+
results.sort(key=lambda x: x[0])
|
| 216 |
+
for icd_code, description, prob in results:
|
| 217 |
+
print(f"- {icd_code} ({description}): {prob:.4f}")
|
| 218 |
+
|
| 219 |
+
if __name__ == "__main__":
|
| 220 |
+
main()
|
| 221 |
```
|
| 222 |
|
| 223 |
> **Note:** `tokens` (the list of token strings per sample) is **required** when `use_attention=True`
|