tttjjj committed
Commit 2f8293f · 1 Parent(s): dc4c942

Add LLM runner to populate submission answers from the SAP

llm/run_llm.py: loads a submission (private intake_form_data dataset) + its
parsed SAP (public source dataset, sap.lines.json with page markers), builds the
null-placeholder prompt block, and runs each configured model (Anthropic /
OpenAI) with the trial-statistician prompt to produce output.json + output.R per
model, written under llm/out/. NCT->doc resolved via data/tdr.parquet.

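The page-marker reconstruction mentioned in the commit message can be illustrated on a toy input. This is a sketch that mirrors `load_sap_text` from `llm/run_llm.py` without the download step; the function name `rebuild_sap_text` and the toy page contents are made up for illustration, while the `pages` / `page` / `lines` / `text` keys are the ones the script reads.

```python
def rebuild_sap_text(data: dict) -> str:
    """Join page lines into one string, inserting a marker before each page."""
    chunks = []
    for page in data.get("pages", []):
        chunks.append(f"\n===== Page {page.get('page', '?')} =====")
        for line in page.get("lines", []):
            txt = (line.get("text") or "").strip()
            if txt:  # whitespace-only lines are dropped
                chunks.append(txt)
    return "\n".join(chunks).strip()


if __name__ == "__main__":
    # Toy stand-in for sap.lines.json (hypothetical content).
    toy = {
        "pages": [
            {"page": 1, "lines": [{"text": "Statistical Analysis Plan"}, {"text": "  "}]},
            {"page": 2, "lines": [{"text": "9.1 Sample size"}]},
        ]
    }
    print(rebuild_sap_text(toy))
    # ===== Page 1 =====
    # Statistical Analysis Plan
    #
    # ===== Page 2 =====
    # 9.1 Sample size
```

The markers give the model something concrete to cite when the prompt asks for section and page references.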
Files changed (4)
  1. .gitignore +3 -0
  2. llm/README.md +65 -0
  3. llm/requirements.txt +5 -0
  4. llm/run_llm.py +371 -0
.gitignore CHANGED
@@ -22,3 +22,6 @@ data/
 
  # OS
  .DS_Store
+
+ # LLM runner outputs
+ llm/out/
llm/README.md ADDED
@@ -0,0 +1,65 @@
+ # LLM runner: populate a submission's answers
+
+ Standalone script that runs one or more LLMs over an intake submission's
+ questions, using the trial's parsed SAP as the only source, and writes the
+ completed `output.json` (+ `output.R`) per model to local files.
+
+ ## Setup
+
+ ```bash
+ pip install -r requirements.txt
+ export HF_TOKEN=hf_...            # read access to the private intake_form_data repo
+ export ANTHROPIC_API_KEY=...      # for claude-* models
+ export OPENAI_API_KEY=...         # for gpt-* models
+ ```
+
+ ## Run
+
+ ```bash
+ python run_llm.py --submission NCT02578680__EricZ
+ # pick a specific SAP doc + models:
+ python run_llm.py --submission NCT02578680__EricZ \
+     --doc-id 10.1056_nejmoa1801005 \
+     --models claude-opus-4-8 gpt-4o
+ ```
+
+ ## What it does
+
+ 1. **Submission**: loads the latest version of
+    `submissions/<submission>/<stamp>.json` from `trialdesignbench/intake_form_data`.
+ 2. **SAP**: maps the NCT id (prefix of the submission name) to a
+    `documents/<doi>` folder via `data/tdr.parquet` (one NCT can map to several
+    docs; use `--doc-id` to pick), then rebuilds SAP text **with page markers**
+    from `sap.lines.json` (so answers can cite page numbers, as the prompt
+    requires).
+ 3. **Prompt block**: turns the questions into the null-placeholder block
+    (`extraction_only` → `extracted_value: null`; `derivation_required` →
+    `dimensions: {inputs_used, method, calculated_value}`).
+ 4. **Models**: sends the task prompt (system) + SAP + prompt block to each
+    model and parses the returned fenced `json` block (output.json) and `r`
+    block (output.R).
+
+ ## Outputs
+
+ ```
+ out/<submission>/
+   prompt_block.json      # the null block the models were asked to fill
+   sap.txt                # SAP text fed to the models (with page markers)
+   <model>/output.json    # completed block
+   <model>/output.R       # R for derivation questions
+   <model>/raw.txt        # raw model response (for debugging)
+   <model>/error.txt      # present only if the call failed
+ ```
+
+ ## Notes / caveats
+
+ - **Context size:** parsed SAPs are large (e.g. NCT02578680 is about 480K
+   chars, roughly 120K tokens). That fits Claude's large context but, with the
+   prompt and response added, may exceed smaller windows (e.g. gpt-4o's 128K).
+   For those, use a long-context model or pre-trim the SAP; the script does not truncate.
+ - **Closed-book:** the prompt forbids outside knowledge; the script only ever
+   sends the one SAP document.
+ - **Models are pluggable:** any `claude-*` id routes to Anthropic, any
+   `gpt-*`/`o1-*`/`o3-*` to OpenAI. Edit `DEFAULT_MODELS` or pass `--models`.
+ - Anthropic calls mark the SAP-bearing block with an ephemeral `cache_control`,
+   so re-runs made while the cache entry is still live are cheaper.
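The context-size caveat in the README can be checked before any API call is made. A minimal sketch, assuming the README's own ratio of roughly 4 characters per token (480K chars for about 120K tokens); `estimate_tokens`, `fits_context`, and the `reserve_tokens` margin are hypothetical names that do not exist in `run_llm.py`, the 200,000 window in the demo is only an illustrative value, and a real tokenizer would give different counts.

```python
# Rough pre-flight check: will the SAP text fit in a model's context window?
# Assumption: ~4 characters per token, the ratio implied by the README.

CHARS_PER_TOKEN = 4  # heuristic, not a tokenizer


def estimate_tokens(text: str) -> int:
    """Estimate the token count from the character count (ceiling division)."""
    return -(-len(text) // CHARS_PER_TOKEN)


def fits_context(text: str, context_window: int, reserve_tokens: int = 8192) -> bool:
    """True if the estimated input plus a reserved budget fits the window.

    reserve_tokens is a hypothetical margin for the system prompt and the
    response; 8192 mirrors the max_tokens value used for Anthropic calls
    in run_llm.py.
    """
    return estimate_tokens(text) + reserve_tokens <= context_window


if __name__ == "__main__":
    sap_text = "x" * 480_000  # stand-in for a parsed SAP of ~480K chars
    print(estimate_tokens(sap_text))        # 120000
    print(fits_context(sap_text, 128_000))  # False: 120000 + 8192 > 128000
    print(fits_context(sap_text, 200_000))  # True
```

A check like this could skip a model, or warn, instead of spending a request that is likely to be rejected for length.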
llm/requirements.txt ADDED
@@ -0,0 +1,5 @@
+ huggingface_hub>=0.25
+ pandas>=2.0
+ pyarrow>=15
+ anthropic>=0.40
+ openai>=1.40
llm/run_llm.py ADDED
@@ -0,0 +1,371 @@
+ #!/usr/bin/env python3
+ """Populate a Trial Design Benchmark submission's answers by running LLMs.
+
+ For one intake submission it:
+   1. loads the submission's latest version (questions + rubrics) from the
+      private HF dataset `trialdesignbench/intake_form_data`,
+   2. resolves the trial's parsed SAP from the public HF dataset
+      `trialdesignbench/source` (documents/<doi>/sap.lines.json),
+   3. asks each configured model to reproduce the statistical design, returning
+      a filled output.json (+ output.R for derivation questions),
+   4. writes the results to local files under --out.
+
+ Usage:
+   export HF_TOKEN=hf_...            # needed for the private submissions repo
+   export ANTHROPIC_API_KEY=...      # for Claude models
+   export OPENAI_API_KEY=...         # for OpenAI models
+   python run_llm.py --submission NCT02578680__EricZ
+   python run_llm.py --submission NCT02578680__EricZ --doc-id 10.1056_nejmoa1801005 \
+       --models claude-opus-4-8 gpt-4o
+
+ Outputs:
+   out/<submission>/<model>/output.json   # completed prompt block
+   out/<submission>/<model>/output.R      # R for derivation questions
+   out/<submission>/<model>/raw.txt       # raw model response
+   out/<submission>/prompt_block.json     # what the models were asked to fill
+   out/<submission>/sap.txt               # SAP text fed to the models
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import re
+ import sys
+ from pathlib import Path
+
+ # ---------------------------------------------------------------------------
+ # The task prompt (verbatim). Used as the system prompt.
+ # ---------------------------------------------------------------------------
+ SYSTEM_PROMPT = "\n".join(
+     [
+         "You are an experienced trial statistician. You will be provided with the Statistical Analysis Plan (SAP) or protocol from a Phase 3 registrational trial. Your task is to reproduce the statistical design by answering the evaluation questions below.",
+         "",
+         "There are two types of evaluation questions:",
+         "",
+         "- Extraction only: locate and report the parameter value directly from the SAP/protocol.",
+         "- Derivation required: identify the source inputs from the SAP/protocol, calculate the requested parameter, explain the calculation method, and provide reproducible R code in output.R that implements the calculation and prints the final result.",
+         "",
+         "Closed-book constraint: use only the input document provided below. Do not draw on prior knowledge of this trial from any external source, including published papers, press releases, registry entries, amendments, or training data. If a value is absent or not derivable from the input document, state this explicitly.",
+         "",
+         "Every reported value must be traceable to a specific section and page of the input document, or to a calculation whose inputs are themselves traceable to a specific section and page.",
+         "",
+         "Every numeric value reported must be expressed to at least 4 decimal places unless otherwise specified.",
+         "",
+         "Do not assume any specific statistical method unless it is explicitly stated or directly derivable from the input document. If multiple methods are plausible, state the assumption made and justify it based on the input document.",
+         "",
+         "Output instructions:",
+         "- Return a file named output.json containing a single block named 'output'. Copy the entire prompt block into 'output' and replace each null value with the extracted or derived result. Do not add, remove, rename, or modify any other fields.",
+         "- Return a separate file named output.R implementing the calculations for all Derivation required questions. For each question the script must print: (1) the source inputs, (2) the calculation method and formula applied, and (3) the final calculated value.",
+     ]
+ )
+
+ # How the harness asks the model to package its two files in one response.
+ RESPONSE_FORMAT_INSTRUCTION = (
+     "Return your answer as exactly two fenced code blocks, in this order:\n"
+     "1. A ```json block containing the completed output.json "
+     "(an object with a single top-level key \"output\").\n"
+     "2. A ```r block containing output.R. If there are no Derivation required "
+     "questions, return an empty ```r block.\n"
+     "Do not include any prose outside the two code blocks."
+ )
+
+ INTAKE_REPO = "trialdesignbench/intake_form_data"
+ SOURCE_REPO = "trialdesignbench/source"
+
+ DEFAULT_MODELS = ["claude-opus-4-8", "gpt-4o"]
+
+
+ # ---------------------------------------------------------------------------
+ # HF dataset access
+ # ---------------------------------------------------------------------------
+
+ def _hf():
+     from huggingface_hub import HfApi
+
+     token = os.environ.get("HF_TOKEN", "").strip() or None
+     return HfApi(token=token), token
+
+
+ def load_submission(submission: str) -> dict:
+     """Load the latest version of submissions/<submission>/<stamp>.json."""
+     api, token = _hf()
+     prefix = f"submissions/{submission}/"
+     try:
+         files = api.list_repo_files(repo_id=INTAKE_REPO, repo_type="dataset")
+     except Exception as e:
+         sys.exit(
+             f"Could not list {INTAKE_REPO} (private?). Set HF_TOKEN with read "
+             f"access. Error: {e}"
+         )
+     versions = sorted(
+         f for f in files if f.startswith(prefix) and f.endswith(".json")
+     )
+     if not versions:
+         # Fall back: maybe the submission is a single flat file.
+         flat = [f for f in files if f == f"submissions/{submission}.json"]
+         if not flat:
+             sys.exit(f"No submission files found under {prefix} in {INTAKE_REPO}.")
+         versions = flat
+     latest = versions[-1]
+     from huggingface_hub import hf_hub_download
+
+     path = hf_hub_download(
+         repo_id=INTAKE_REPO, repo_type="dataset", filename=latest, token=token
+     )
+     with open(path, encoding="utf-8") as fh:
+         rec = json.load(fh)
+     print(f" submission version: {latest}")
+     return rec
+
+
+ def resolve_doc_id(nct_id: str, override: str | None) -> str:
+     """Map an NCT id to a documents/<doi> folder via tdr.parquet, or use override."""
+     api, token = _hf()
+     if override:
+         return override
+     from huggingface_hub import hf_hub_download
+
+     parquet = hf_hub_download(
+         repo_id=SOURCE_REPO, repo_type="dataset", filename="data/tdr.parquet", token=token
+     )
+     import pandas as pd
+
+     df = pd.read_parquet(parquet)
+     rows = df[df["NCT ID"].astype(str).str.strip() == nct_id]
+     if rows.empty:
+         sys.exit(f"NCT {nct_id} not found in tdr.parquet; pass --doc-id explicitly.")
+
+     existing = set(_list_doc_folders(api))
+     candidates = []
+     for link in rows["Paper Link"].dropna().astype(str):
+         m = re.search(r"(10\.\d{4,9}/\S+)", link)
+         if not m:
+             continue
+         folder = m.group(1).replace("/", "_").rstrip(".")
+         candidates.append(folder)
+     for folder in candidates:
+         if folder in existing:
+             print(f" resolved {nct_id} -> documents/{folder}")
+             return folder
+     sys.exit(
+         f"None of the DOI folders for {nct_id} exist in {SOURCE_REPO}: {candidates}. "
+         f"Pass --doc-id explicitly."
+     )
+
+
+ def _list_doc_folders(api) -> list[str]:
+     files = api.list_repo_files(repo_id=SOURCE_REPO, repo_type="dataset")
+     return sorted({f.split("/")[1] for f in files if f.startswith("documents/") and "/" in f[len("documents/") :]})
+
+
+ def load_sap_text(doc_id: str) -> str:
+     """Reconstruct SAP text with page markers from documents/<doc>/sap.lines.json."""
+     api, token = _hf()
+     from huggingface_hub import hf_hub_download
+
+     fname = f"documents/{doc_id}/sap.lines.json"
+     try:
+         path = hf_hub_download(
+             repo_id=SOURCE_REPO, repo_type="dataset", filename=fname, token=token
+         )
+     except Exception as e:
+         sys.exit(f"Could not download {fname} from {SOURCE_REPO}: {e}")
+     with open(path, encoding="utf-8") as fh:
+         data = json.load(fh)
+
+     chunks = []
+     for page in data.get("pages", []):
+         pageno = page.get("page", "?")
+         chunks.append(f"\n===== Page {pageno} =====")
+         for line in page.get("lines", []):
+             txt = (line.get("text") or "").strip()
+             if txt:
+                 chunks.append(txt)
+     text = "\n".join(chunks).strip()
+     if not text:
+         sys.exit(f"SAP text for {doc_id} was empty.")
+     return text
+
+
+ # ---------------------------------------------------------------------------
+ # Build the prompt block (questions with null placeholders)
+ # ---------------------------------------------------------------------------
+
+ def build_prompt_block(submission: dict) -> dict:
+     prompts = (submission.get("comparison") or {}).get("prompts") or []
+     block = []
+     for q in prompts:
+         de = q.get("design_element", "")
+         if de == "Others" and q.get("design_element_other"):
+             de = q["design_element_other"]
+         qtype = q.get("question_type", "")
+         if qtype == "derivation_required":
+             output = {"dimensions": {"inputs_used": None, "method": None, "calculated_value": None}}
+         else:  # extraction_only (default)
+             output = {"extracted_value": None}
+         block.append(
+             {
+                 "id": q.get("id", ""),
+                 "design_element": de,
+                 "question": q.get("question", ""),
+                 "question_type": qtype,
+                 "output": output,
+             }
+         )
+     return {"prompt": block}
+
+
+ def build_user_message(sap_text: str, prompt_block: dict) -> str:
+     return (
+         "INPUT DOCUMENT (SAP / protocol):\n"
+         "<<<BEGIN DOCUMENT>>>\n"
+         f"{sap_text}\n"
+         "<<<END DOCUMENT>>>\n\n"
+         "PROMPT BLOCK — copy this whole block into 'output' and replace each null:\n"
+         "```json\n"
+         f"{json.dumps(prompt_block, indent=2, ensure_ascii=False)}\n"
+         "```\n\n"
+         f"{RESPONSE_FORMAT_INSTRUCTION}"
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # Model callers: return raw response text
+ # ---------------------------------------------------------------------------
+
+ def call_anthropic(model: str, sap_text: str, user_msg: str) -> str:
+     from anthropic import Anthropic
+
+     client = Anthropic()  # reads ANTHROPIC_API_KEY
+     # Cache the large SAP-bearing user block so re-runs are cheaper.
+     resp = client.messages.create(
+         model=model,
+         max_tokens=8192,
+         system=[{"type": "text", "text": SYSTEM_PROMPT}],
+         messages=[
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": user_msg,
+                         "cache_control": {"type": "ephemeral"},
+                     }
+                 ],
+             }
+         ],
+     )
+     return "".join(b.text for b in resp.content if getattr(b, "type", "") == "text")
+
+
+ def call_openai(model: str, sap_text: str, user_msg: str) -> str:
+     from openai import OpenAI
+
+     client = OpenAI()  # reads OPENAI_API_KEY
+     resp = client.chat.completions.create(
+         model=model,
+         messages=[
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": user_msg},
+         ],
+     )
+     return resp.choices[0].message.content or ""
+
+
+ def call_model(model: str, sap_text: str, user_msg: str) -> str:
+     if model.startswith(("claude", "anthropic")):
+         return call_anthropic(model, sap_text, user_msg)
+     if model.startswith(("gpt", "o1", "o3", "openai")):
+         return call_openai(model, sap_text, user_msg)
+     raise ValueError(f"Unknown model '{model}'; prefix with claude-/gpt-/o1-...")
+
+
+ # ---------------------------------------------------------------------------
+ # Parse the two fenced blocks out of a response
+ # ---------------------------------------------------------------------------
+
+ def extract_blocks(text: str) -> tuple[dict | None, str, str | None]:
+     """Return (parsed_output_json, output_r, json_parse_error)."""
+     json_match = re.search(r"```json\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)
+     r_match = re.search(r"```r\b\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)  # \b: skip ```rust etc.
+     output_json, err = None, None
+     if json_match:
+         raw = json_match.group(1).strip()
+         try:
+             output_json = json.loads(raw)
+         except Exception as e:
+             err = f"{e}"
+     else:
+         err = "no ```json block found"
+     output_r = r_match.group(1).strip() if r_match else ""
+     return output_json, output_r, err
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ def main() -> None:
+     ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
+     ap.add_argument("--submission", default="NCT02578680__EricZ",
+                     help="submission folder name <trial>__<user>")
+     ap.add_argument("--doc-id", default=None,
+                     help="documents/<doc-id> folder (default: resolve from NCT via tdr.parquet)")
+     ap.add_argument("--models", nargs="+", default=DEFAULT_MODELS,
+                     help=f"model ids to run (default: {DEFAULT_MODELS})")
+     ap.add_argument("--out", default="out", help="output directory")
+     args = ap.parse_args()
+
+     nct_id = args.submission.split("__")[0]
+     out_dir = Path(args.out) / args.submission
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     print(f"Submission: {args.submission} (NCT {nct_id})")
+     submission = load_submission(args.submission)
+     doc_id = resolve_doc_id(nct_id, args.doc_id)
+     sap_text = load_sap_text(doc_id)
+     print(f" SAP chars: {len(sap_text):,}")
+
+     prompt_block = build_prompt_block(submission)
+     n_q = len(prompt_block["prompt"])
+     print(f" questions: {n_q}")
+     (out_dir / "prompt_block.json").write_text(
+         json.dumps(prompt_block, indent=2, ensure_ascii=False), encoding="utf-8"
+     )
+     (out_dir / "sap.txt").write_text(sap_text, encoding="utf-8")
+
+     if n_q == 0:
+         sys.exit("Submission has no questions; nothing to run.")
+
+     user_msg = build_user_message(sap_text, prompt_block)
+
+     for model in args.models:
+         safe = re.sub(r"[^a-zA-Z0-9._-]", "_", model)
+         mdir = out_dir / safe
+         mdir.mkdir(parents=True, exist_ok=True)
+         print(f"\n>>> {model}")
+         try:
+             raw = call_model(model, sap_text, user_msg)
+         except Exception as e:
+             print(f" FAILED: {e}")
+             (mdir / "error.txt").write_text(str(e), encoding="utf-8")
+             continue
+         (mdir / "raw.txt").write_text(raw, encoding="utf-8")
+         output_json, output_r, err = extract_blocks(raw)
+         if output_json is not None:
+             (mdir / "output.json").write_text(json.dumps(output_json, indent=2, ensure_ascii=False), encoding="utf-8")
+             out = output_json.get("output") if isinstance(output_json, dict) else output_json
+             answers = out.get("prompt") if isinstance(out, dict) else out  # copied block or bare list
+             print(f" wrote output.json ({len(answers) if isinstance(answers, list) else 0} answers)")
+         else:
+             print(f" could not parse output.json: {err} (see raw.txt)")
+         (mdir / "output.R").write_text(output_r or "", encoding="utf-8")
+         print(f" wrote output.R ({len(output_r or '')} chars)")
+
+     print(f"\nDone. Results in {out_dir}/")
+
+
+ if __name__ == "__main__":
+     main()
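The system prompt asks models to copy the prompt block into `output` and change only the nulls, but the script writes whatever JSON it manages to parse. A possible post-run check is sketched below, under the assumption that a compliant response has the shape `{"output": {"prompt": [...]}}` (a bare list under `output` is accepted too). `check_output` and its helpers are hypothetical and not part of this commit, and the values in the demo are toy data.

```python
def _answers(output_json):
    """Return the list of answered questions, or [] if the shape is unexpected."""
    out = output_json.get("output") if isinstance(output_json, dict) else None
    if isinstance(out, dict):
        out = out.get("prompt")
    return out if isinstance(out, list) else []


def _has_null(value) -> bool:
    """True if any leaf of a nested dict/list is still None."""
    if value is None:
        return True
    if isinstance(value, dict):
        return any(_has_null(v) for v in value.values())
    if isinstance(value, list):
        return any(_has_null(v) for v in value)
    return False


def check_output(prompt_block: dict, output_json) -> dict:
    """Compare question ids and report answers that still contain nulls."""
    expected = [q["id"] for q in prompt_block["prompt"]]
    answers = [a for a in _answers(output_json) if isinstance(a, dict)]
    got = [a.get("id") for a in answers]
    return {
        "missing_ids": [i for i in expected if i not in got],
        "extra_ids": [i for i in got if i not in expected],
        "unfilled_ids": [a.get("id") for a in answers if _has_null(a.get("output"))],
    }


if __name__ == "__main__":
    # Toy prompt block and response (made-up ids and values).
    block = {"prompt": [
        {"id": "q1", "output": {"extracted_value": None}},
        {"id": "q2", "output": {"dimensions": {
            "inputs_used": None, "method": None, "calculated_value": None}}},
    ]}
    response = {"output": {"prompt": [
        {"id": "q1", "output": {"extracted_value": "0.0250"}},
        {"id": "q2", "output": {"dimensions": {
            "inputs_used": "toy inputs", "method": None, "calculated_value": "512.0000"}}},
    ]}}
    print(check_output(block, response))
    # {'missing_ids': [], 'extra_ids': [], 'unfilled_ids': ['q2']}
```

Writing such a report next to `output.json` would make partially filled responses visible without opening `raw.txt`.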