---
license: apache-2.0
base_model: Qwen/Qwen2.5-7B-Instruct
tags:
- medical
- clinical-reasoning
- diagnosis
- grpo
- rlhf
- lora
- qwen
datasets:
- gretelai/symptom_to_diagnosis
language:
- en
pipeline_tag: text-generation
library_name: peft
---

# dx-reasoning-qwen2.5-grpo

A LoRA adapter for clinical diagnostic reasoning, fine-tuned from Qwen2.5-7B-Instruct using Group Relative Policy Optimisation (GRPO).

## Model description

This model was trained to improve clinical diagnostic reasoning by learning to generate step-by-step reasoning before providing a diagnosis. It uses a structured format with `` and `` tags.

### Training details

- **Base model:** Qwen/Qwen2.5-7B-Instruct
- **Training method:** GRPO (Group Relative Policy Optimisation)
- **LoRA configuration:**
  - Rank (r): 64
  - Alpha: 128
  - Dropout: 0.05
  - Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- **Trainable parameters:** 161M / 7.7B total (2.08%)
- **Training steps:** 700 (2+ epochs)
- **Hardware:** 2x NVIDIA H100 80GB
- **Training time:** ~20 hours

### Reward function

The model was trained with a composite reward function:

- **Embedding similarity:** Cosine similarity between generated diagnosis and ground truth using PubMedBERT embeddings
- **Reasoning quality:** Bonus for including structured reasoning steps

### Dataset

Trained on [gretelai/symptom_to_diagnosis](https://huggingface.co/datasets/gretelai/symptom_to_diagnosis):

- 853 training samples
- 200 evaluation samples

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo")

# Example inference
prompt = """You are a medical expert.
Given the patient's symptoms, provide a diagnosis.
Patient symptoms: The patient presents with severe headache, sensitivity to light, neck stiffness, and fever.
First, provide your reasoning in tags, then give your diagnosis in tags."""

# Wrap the prompt in the chat template expected by the instruct model
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# do_sample=True makes explicit that temperature-based sampling is intended
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Rollouts

The `rollouts/` directory contains training rollouts: generation samples from each evaluation step (100, 200, 300, 400, 500, 600, 700). These can be used for per-diagnosis analysis of training progression.

## Limitations

- Trained on a relatively small dataset (853 samples)
- Focused on the symptom-to-diagnosis task; may not generalise to other medical reasoning tasks
- For research purposes only; should not be used for actual medical diagnosis

## Citation

If you use this model, please cite:

```bibtex
@misc{dx-reasoning-qwen2.5-grpo,
  author = {Chris von Csefalvay},
  title = {dx-reasoning-qwen2.5-grpo: Clinical Diagnostic Reasoning with GRPO},
  year = {2026},
  publisher = {Hugging Face},
  url = {https://huggingface.co/chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo}
}
```

## Training logs

Training was monitored via Weights & Biases. See the project for detailed metrics and training curves.
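
For readers who want a concrete picture of the composite reward described under "Reward function", the following is a minimal sketch and not the training code. It assumes the two components are simply added, that the reasoning bonus is a flat constant (`reasoning_bonus`, an illustrative value not taken from this model card), and that embeddings arrive as plain float vectors; in practice they would come from a PubMedBERT encoder. The function names are hypothetical.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def composite_reward(
    predicted_embedding: Sequence[float],
    truth_embedding: Sequence[float],
    has_structured_reasoning: bool,
    reasoning_bonus: float = 0.2,  # assumed value, not from the model card
) -> float:
    """Embedding-similarity term plus a flat bonus for structured reasoning.

    Additive combination is an assumption; the actual weighting is not
    documented in this model card.
    """
    similarity = cosine_similarity(predicted_embedding, truth_embedding)
    bonus = reasoning_bonus if has_structured_reasoning else 0.0
    return similarity + bonus


# Toy 3-dimensional vectors standing in for real diagnosis embeddings.
predicted = [0.9, 0.1, 0.0]
truth = [1.0, 0.0, 0.0]
print(composite_reward(predicted, truth, has_structured_reasoning=True))
```

Keeping the similarity term and the bonus as separate arguments makes it easy to inspect each component's contribution when analysing the rollouts.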