---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
tags:
- medical
- slm
- healthcare
- pytorch
- supervised-fine-tuning
- clinical-reasoning
datasets:
- pubmed_qa
- BI55/MedText
- Mohammed-Altaf/medical-instruction-100k
metrics:
- perplexity
- loss
library_name: transformers
---

# 🏥 VitalLM-50M-Instruct: Instruction-Tuned Medical SLM
> **A 50.55-million-parameter Small Language Model (SLM) fine-tuned for instruction-following clinical dialogue, combining deep biomedical pretraining with supervised instruction alignment.**

VitalLM-50M-Instruct is the instruction-tuned successor to VitalLM-50M. It is built on a custom decoder-only Transformer pretrained on 764M+ biomedical tokens. The model was then refined via **Supervised Fine-Tuning (SFT)** on a curated medical instruction dataset, so it can follow clinical prompts, answer patient queries, and generate structured medical responses.

---

## 🚀 Key Architectural Choices

### 1. SwiGLU Activation Function

The original Transformer used ReLU in its feed-forward layers, and GPT-2 uses GELU. VitalLM-50M instead uses **SwiGLU**, a gated variant in which a Swish-activated projection multiplies a second linear projection element-wise. Gated feed-forward layers of this kind have been reported to improve language-modeling quality at a comparable parameter count. Here they are intended to help the model capture the complex, non-linear relationships found in medical symptoms and drug interactions.

### 2. Specialized Biomedical Tokenization

A custom **ByteLevelBPE tokenizer** with a 16,384-token vocabulary keeps medical terminology as meaningful units (e.g., it avoids fragmenting terms like `bronchitis` or `tachycardia`). The authors credit this choice with improving both inference speed and accuracy. Fewer tokens per term means shorter sequences, so more clinical text fits in the 256-token context window.

---

## 📊 Technical Specifications

| Parameter | Value | Notes |
| :--- | :--- | :--- |
| **Total Parameters** | 50.55 million | Optimized for edge/mobile deployment |
| **Architecture** | Decoder-only Transformer | Custom GPT-style |
| **Layers (`n_layer`)** | 10 | Hierarchical clinical reasoning |
| **Attention Heads (`n_head`)** | 8 | Multi-head attention |
| **Embedding Dim (`n_embd`)** | 512 | Medical concept vector space |
| **Context Window (`block_size`)** | 256 tokens | Clinical dialogues & Q&A |
| **Activation** | SwiGLU | Gated feed-forward layers |
| **Tokenizer** | ByteLevelBPE | Vocabulary size: 16,384 |

---

## 📈 Training Stage 1: Pretraining

### Data Strategy

- **Corpus**: 550M+ tokens of filtered biomedical research, clinical guidelines, and synthetic medical dialogues.
- **Sources**: PubMed QA, MedMCQA, BI55/MedText.
- **Pre-processing**: Extensive de-duplication and signal-preserving cleaning.

### Hardware & Optimization

- **Compute**: NVIDIA P100 GPU (Kaggle)
- **Optimizer**: AdamW with weight decay (0.1)
- **Scheduler**: Cosine learning-rate decay
- **Strategy**: Multi-session training with custom state-recovery logic

### Pretraining Results

| Metric | Value |
|:---|:---|
| Final Training Loss | 3.32 |
| Final Validation Loss | 3.66 |
| Generalization Gap (validation − training) | 0.34 |

---

## 🎯 Training Stage 2: Supervised Fine-Tuning (SFT)

### SFT Dataset

- **Dataset**: [`Mohammed-Altaf/medical-instruction-100k`](https://huggingface.co/datasets/Mohammed-Altaf/medical-instruction-100k)
- **Size**: ~100,000 instruction-response pairs
- **Format**: Instruction-following medical Q&A covering symptoms, diagnoses, treatments, and clinical dialogue

### SFT Objective

The model was fine-tuned to shift from **open-ended generation** (pretraining) to **structured instruction-following**, so that it responds reliably to clinical prompts in a doctor-patient dialogue format.

### SFT Hardware & Optimization

- **Compute**: NVIDIA P100 GPU (Kaggle)
- **Optimizer**: AdamW with weight decay (0.1)
- **Scheduler**: Cosine learning-rate decay with linear warmup (peak LR: 2e-5)
- **Training Duration**: ~4,300 iterations

### SFT Results

| Metric | Value |
|:---|:---|
| Best Training Loss | **2.9866** |
| Final Training Loss | ~2.96 |
| Final Validation Loss | ~2.99 |
| Final Train Perplexity | ~19.5 |
| Final Val Perplexity | ~19.8 |

---

## 🛠 Usage & Implementation

### Download Required Files

Before running any code, download the following files from the sources listed:

| File | Source | Description |
|:---|:---|:---|
| `model.py` | [GitHub](https://github.com/Aman041902/VitalLM-50M/blob/main/model.py) | Custom model architecture |
| `vocab_50m.json` | [Hugging Face](https://huggingface.co/aman0419/Vitallm-50M-Instruct) | Tokenizer vocabulary |
| `merges_50m.txt` | [Hugging Face](https://huggingface.co/aman0419/Vitallm-50M-Instruct) | BPE merge rules |

> ⚠️ All three files must be in the **same working directory** before you run inference. `model.py` defines the custom `SLM` and `SLMConfig` classes. These are not part of the standard `transformers` library, so the file cannot be skipped.

---

### Install Dependencies

```bash
pip install torch transformers tokenizers safetensors huggingface_hub
```

---

### Loading the Instruction-Tuned Model

```python
import torch
import torch.nn.functional as F
from model import SLM, SLMConfig
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the safetensors weights from the Hugging Face Hub
weights_path = hf_hub_download(
    repo_id="aman0419/Vitallm-50M-Instruct",
    filename="model.safetensors",  # use vital_lm_50m_weights.safetensors for the pretrained (non-SFT) model
)

# Initialize the model with the architecture from the specification table
config = SLMConfig(vocab_size=16384, n_layer=10, n_head=8, n_embd=512, block_size=256, dropout=0.0)
model = SLM(config)

# Load the weights. The tied embedding / LM-head weight may be stored only once
# (as lm_head.weight), so copy it back under transformer.wte.weight before loading.
state_dict = load_file(weights_path)
if 'lm_head.weight' in state_dict and 'transformer.wte.weight' not in state_dict:
    state_dict['transformer.wte.weight'] = state_dict['lm_head.weight']
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Load the tokenizer; every special-token role reuses <|endoftext|>
base_tokenizer = ByteLevelBPETokenizer(vocab="vocab_50m.json", merges="merges_50m.txt")
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=base_tokenizer,
    eos_token="<|endoftext|>",
bos_token="<|endoftext|>", unk_token="<|endoftext|>", pad_token="<|endoftext|>" ) ``` ### Generation Function ```python def generate(prompt, max_new_tokens=130, temperature=0.25, top_k=30, top_p=0.9, repetition_penalty=1.25): input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device) with torch.no_grad(): for _ in range(max_new_tokens): input_ids_cond = input_ids[:, -256:] logits, _ = model(input_ids_cond) logits = logits[:, -1, :] / temperature for token in set(input_ids[0].tolist()): if logits[0, token] > 0: logits[0, token] /= repetition_penalty else: logits[0, token] *= repetition_penalty sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 logits[0, sorted_indices[sorted_indices_to_remove]] = -float('Inf') if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) input_ids = torch.cat((input_ids, next_token), dim=1) if next_token.item() == tokenizer.eos_token_id: break return tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True) # Test if __name__ == "__main__": prompt = "Patient: I have been feeling very thirsty and urinating frequently. Doctor:" response = generate(prompt) print(f"Response: {response}") ``` ### Recommended Prompt Format For best results with the SFT model, use the following dialogue-style format: ``` Patient: