---
license: apache-2.0
base_model: Qwen/Qwen2.5-7B-Instruct
tags:
- medical
- clinical-reasoning
- diagnosis
- grpo
- rlhf
- lora
- qwen
datasets:
- gretelai/symptom_to_diagnosis
language:
- en
pipeline_tag: text-generation
library_name: peft
---

# dx-reasoning-qwen2.5-grpo

A LoRA adapter for clinical diagnostic reasoning, fine-tuned from Qwen2.5-7B-Instruct using Group Relative Policy Optimisation (GRPO).

## Model description

This adapter was trained to generate step-by-step clinical reasoning before committing to a diagnosis. Its outputs follow a structured format with `<reasoning>` and `<diagnosis>` tags.
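
Downstream code usually needs the two tagged sections separately. The following is a minimal parsing sketch, not part of the released adapter. It assumes each section is closed with a matching `</reasoning>` or `</diagnosis>` tag; the names `ParsedCompletion` and `parse_completion` are hypothetical.

```python
import re
from typing import NamedTuple, Optional


class ParsedCompletion(NamedTuple):
    """Hypothetical container for the two tagged sections of a completion."""
    reasoning: Optional[str]
    diagnosis: Optional[str]


def _extract(tag: str, text: str) -> Optional[str]:
    """Return the stripped contents of the first <tag>...</tag> pair, or None."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None


def parse_completion(text: str) -> ParsedCompletion:
    """Split a completion into reasoning and diagnosis (assumes closing tags exist)."""
    return ParsedCompletion(_extract("reasoning", text), _extract("diagnosis", text))


example = (
    "<reasoning>Fever with neck stiffness and photophobia suggests "
    "meningeal irritation.</reasoning>\n"
    "<diagnosis>meningitis</diagnosis>"
)
parsed = parse_completion(example)
print(parsed.diagnosis)
```

A missing or unclosed tag yields `None` for that field, so callers can treat malformed completions explicitly rather than guessing.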

### Training details

- **Base model:** Qwen/Qwen2.5-7B-Instruct
- **Training method:** GRPO (Group Relative Policy Optimisation)
- **LoRA configuration:**
  - Rank (r): 64
  - Alpha: 128
  - Dropout: 0.05
  - Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- **Trainable parameters:** 161M / 7.7B total (2.08%)
- **Training steps:** 700 (2+ epochs)
- **Hardware:** 2x NVIDIA H100 80GB
- **Training time:** ~20 hours
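
The trainable-parameter figure above can be sanity-checked with a little arithmetic. A LoRA adapter of rank `r` on a linear layer with `in` inputs and `out` outputs adds `r * (in + out)` parameters. The sketch below assumes the layer dimensions from the published Qwen2.5-7B-Instruct configuration (hidden size 3584, intermediate size 18944, 28 layers, 4 key/value heads of dimension 128); those values are not stated in this card, and the helper name `lora_params` is hypothetical.

```python
# Assumed dimensions from the published Qwen2.5-7B-Instruct config (not stated in this card).
HIDDEN = 3584
INTERMEDIATE = 18944
NUM_LAYERS = 28
KV_DIM = 4 * 128  # 4 key/value heads x head dimension 128 (grouped-query attention)

RANK = 64  # LoRA rank from the configuration above

# (in_features, out_features) for each targeted projection in one decoder layer.
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters in the A (rank x in) and B (out x rank) matrices of one adapter."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in SHAPES.values())
total = per_layer * NUM_LAYERS
print(f"{total:,}")  # 161,480,704
```

Under these assumptions the count comes to roughly 161M, consistent with the figure reported above.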

### Reward function

The model was trained with a composite reward function:
- **Embedding similarity:** Cosine similarity between generated diagnosis and ground truth using PubMedBERT embeddings
- **Reasoning quality:** Bonus for including structured reasoning steps
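
To make the two components concrete, here is a minimal sketch of how such a composite reward might be assembled. It is not the training code: the bonus weight, the zero reward for a missing diagnosis, and the function names are all hypothetical, and `toy_embed` is a stand-in for PubMedBERT embeddings so that the example runs without model downloads.

```python
import math
import re
from typing import Callable, Optional, Sequence

REASONING_BONUS = 0.2  # hypothetical weight; the actual value is not given in this card


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def _section(tag: str, text: str) -> Optional[str]:
    """Stripped contents of the first <tag>...</tag> pair, or None if absent."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else None


def composite_reward(
    completion: str,
    target: str,
    embed: Callable[[str], Sequence[float]],
    bonus: float = REASONING_BONUS,
) -> float:
    """Embedding similarity of the diagnosis plus a bonus for non-empty reasoning."""
    diagnosis = _section("diagnosis", completion)
    if not diagnosis:
        return 0.0  # hypothetical choice: no parseable diagnosis earns nothing
    similarity = cosine_similarity(embed(diagnosis), embed(target))
    has_reasoning = bool(_section("reasoning", completion))
    return similarity + (bonus if has_reasoning else 0.0)


def toy_embed(text: str) -> list:
    """Placeholder embedder (word counts over a tiny vocabulary), NOT PubMedBERT."""
    vocab = ["meningitis", "migraine", "influenza"]
    words = text.lower().split()
    return [float(words.count(w)) for w in vocab]


good = "<reasoning>Neck stiffness and fever.</reasoning><diagnosis>meningitis</diagnosis>"
print(composite_reward(good, "meningitis", toy_embed))
```

In practice `embed` would wrap a sentence-embedding model; keeping it as a parameter lets the scoring logic be tested independently of the embedder.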

### Dataset

Trained on [gretelai/symptom_to_diagnosis](https://huggingface.co/datasets/gretelai/symptom_to_diagnosis):
- 853 training samples
- 200 evaluation samples

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo")

# Example inference
prompt = """You are a medical expert. Given the patient's symptoms, provide a diagnosis.

Patient symptoms: The patient presents with severe headache, sensitivity to light, neck stiffness, and fever.

First, provide your reasoning in <reasoning> tags, then give your diagnosis in <diagnosis> tags."""

messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

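# Enable sampling explicitly so that the temperature setting takes effect
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))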
```

## Rollouts

Training rollouts are available in the `rollouts/` directory, containing generation samples at each evaluation step (100, 200, 300, 400, 500, 600, 700). These can be used for per-diagnosis analysis of training progression.
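
As an illustration of per-diagnosis analysis, the sketch below averages reward by target diagnosis and evaluation step. The on-disk layout of `rollouts/` is not documented here, so the record fields `step`, `target`, and `reward` are hypothetical, as is the function name; adapt them to the actual files.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Iterable, Mapping


def mean_reward_by_diagnosis(records: Iterable[Mapping]) -> Dict[str, Dict[int, float]]:
    """Map each target diagnosis to {step: mean reward}, with steps in ascending order."""
    grouped = defaultdict(list)
    for record in records:
        # Field names are assumptions about the rollout schema.
        grouped[(record["target"], record["step"])].append(record["reward"])

    table = defaultdict(dict)
    for (target, step), rewards in grouped.items():
        table[target][step] = mean(rewards)
    return {target: dict(sorted(steps.items())) for target, steps in table.items()}


# Toy records standing in for parsed rollout files.
records = [
    {"step": 200, "target": "migraine", "reward": 1.0},
    {"step": 100, "target": "migraine", "reward": 0.25},
    {"step": 100, "target": "migraine", "reward": 0.75},
    {"step": 100, "target": "influenza", "reward": 0.5},
]
progress = mean_reward_by_diagnosis(records)
print(progress["migraine"])
```

Comparing the per-step means for a given diagnosis shows whether the policy improved on that condition over training or plateaued early.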

## Limitations

- Trained on a relatively small dataset (853 samples)
- Focused on symptom-to-diagnosis task; may not generalise to other medical reasoning tasks
- Not for use in actual medical diagnosis; for research purposes only

## Citation

If you use this model, please cite:

```bibtex
@misc{dx-reasoning-qwen2.5-grpo,
  author = {Chris von Csefalvay},
  title = {dx-reasoning-qwen2.5-grpo: Clinical Diagnostic Reasoning with GRPO},
  year = {2026},
  publisher = {Hugging Face},
  url = {https://huggingface.co/chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo}
}
```

## Training logs

Training was monitored via Weights & Biases. See the project for detailed metrics and training curves.