tdb-intake / app.py
tttjjj's picture
Make submissions editable; load by trial_id + username
089c7d1
Raw
History Blame
9.89 kB
"""Trial Design Benchmark — Intake form.
Run locally:
streamlit run app.py
Deployed to Hugging Face Space; submissions are committed to an HF Dataset repo.
See README.md for setup.
"""
from __future__ import annotations
import json
from typing import List
import streamlit as st
from lib.schema import (
DESIGN_ELEMENTS,
QUESTION_TYPES,
Question,
blank_question,
next_question_id,
rubrics_for_type,
)
from lib.storage import get_submission_by_key, hf_configured, save_submission
st.set_page_config(
page_title="TDB Intake",
page_icon="🔬",
layout="centered",
)
# ------------- state init ------------------------------------------------
if "questions" not in st.session_state:
st.session_state.questions = [] # type: List[Question]
if "trial_id" not in st.session_state:
st.session_state.trial_id = ""
if "username" not in st.session_state:
st.session_state.username = ""
if "last_result" not in st.session_state:
st.session_state.last_result = None
# Bumped whenever we replace the questions wholesale (load / reset) so the
# per-question widgets get fresh keys and actually show the new values.
if "form_nonce" not in st.session_state:
st.session_state.form_nonce = 0
# ------------- callbacks -------------------------------------------------
def _add_question() -> None:
qid = next_question_id(st.session_state.questions)
st.session_state.questions.append(blank_question(qid))
def _remove_question(idx: int) -> None:
st.session_state.questions.pop(idx)
def _change_question_type(q_idx: int, new_type: str) -> None:
"""Regenerate rubrics when question_type changes."""
q = st.session_state.questions[q_idx]
if q["question_type"] != new_type:
q["question_type"] = new_type
q["rubrics"] = rubrics_for_type(new_type)
def _save_draft() -> None:
"""Stash form into session_state (already there); just signal."""
st.session_state.last_result = {"kind": "draft", "msg": "Draft saved in this browser session."}
def _load() -> None:
"""Load an existing submission by (trial_id, username) into the form."""
trial_id = st.session_state.trial_id.strip()
username = st.session_state.username.strip()
if not trial_id or not username:
st.session_state.last_result = {
"kind": "error",
"msg": "Enter trial_id and username, then click Load.",
}
return
try:
record = get_submission_by_key(trial_id, username)
except Exception as e:
st.session_state.last_result = {"kind": "error", "msg": f"Load failed: {e}"}
return
if not record:
st.session_state.last_result = {
"kind": "info",
"msg": f"No existing submission for `{trial_id}` / `{username}`. "
"Add questions and Submit to create one.",
}
return
prompts = (record.get("comparison") or {}).get("prompts") or []
st.session_state.questions = prompts
st.session_state.form_nonce += 1 # force question widgets to refresh
updated = record.get("updatedAt") or record.get("submittedAt") or ""
st.session_state.last_result = {
"kind": "success",
"msg": f"Loaded {len(prompts)} question(s) (last updated {updated}). "
"Edit and Submit to update.",
}
def _submit() -> None:
trial_id = st.session_state.trial_id.strip()
username = st.session_state.username.strip()
if not trial_id or not username:
st.session_state.last_result = {"kind": "error", "msg": "trial_id and username are required."}
return
comparison = {
"trial_id": trial_id,
"username": username,
"prompts": st.session_state.questions,
}
try:
result = save_submission(trial_id, username, comparison)
verb = "Updated" if result.get("updated") else "Submitted"
st.session_state.last_result = {
"kind": "success",
"msg": f"{verb}: `{result['submissionId']}`. "
"You can keep editing and Submit again to update.",
"url": result.get("url"),
}
# Keep the form populated so the user can continue editing.
except Exception as e:
st.session_state.last_result = {"kind": "error", "msg": f"Submit failed: {e}"}
# ------------- header ----------------------------------------------------
st.title("Trial Design Benchmark")
st.caption("Statistician intake form")
if not hf_configured:
st.info(
"ℹ️ HF env vars not set — submissions will be written to `./data/submissions/` "
"(local dev mode)."
)
# ------------- top fields ------------------------------------------------
c1, c2 = st.columns(2)
with c1:
st.text_input("trial_id", key="trial_id", placeholder="e.g., NCT01234567")
with c2:
st.text_input("username", key="username", placeholder="e.g., jdoe")
st.button(
"Load existing submission",
on_click=_load,
help="If you already submitted for this trial_id + username, load it back to edit.",
)
st.divider()
# ------------- questions list --------------------------------------------
st.subheader("Questions")
if not st.session_state.questions:
st.caption('No questions yet. Click "Add question" below to begin.')
n = st.session_state.form_nonce # widget-key namespace; changes on load/reset
for i, q in enumerate(st.session_state.questions):
with st.container(border=True):
head_l, head_r = st.columns([6, 1])
with head_l:
new_id = st.text_input(
"id",
value=q["id"],
key=f"q_{n}_{i}_id",
label_visibility="collapsed",
)
q["id"] = new_id
with head_r:
st.button("Remove", key=f"rm_{n}_{i}", on_click=_remove_question, args=(i,))
col1, col2 = st.columns(2)
with col1:
de_options = [""] + DESIGN_ELEMENTS
de_idx = de_options.index(q["design_element"]) if q["design_element"] in de_options else 0
new_de = st.selectbox(
"design_element",
options=de_options,
index=de_idx,
key=f"q_{n}_{i}_de",
format_func=lambda x: "— select —" if x == "" else x,
)
q["design_element"] = new_de
if new_de == "Others":
q["design_element_other"] = st.text_input(
"Specify other design element",
value=q.get("design_element_other", ""),
key=f"q_{n}_{i}_de_other",
)
else:
q["design_element_other"] = ""
with col2:
qt_options = [""] + QUESTION_TYPES
qt_idx = qt_options.index(q["question_type"]) if q["question_type"] in qt_options else 0
new_qt = st.selectbox(
"question_type",
options=qt_options,
index=qt_idx,
key=f"q_{n}_{i}_qt",
format_func=lambda x: "— select —" if x == "" else x,
)
# If question_type changed, regenerate rubrics via callback-like pattern.
if new_qt != q["question_type"]:
_change_question_type(i, new_qt)
new_question = st.text_input(
"question",
value=q["question"],
key=f"q_{n}_{i}_question",
placeholder="e.g., Alpha allocated to PFS",
)
q["question"] = new_question
if q["rubrics"]:
st.markdown(f"**Rubrics ({len(q['rubrics'])})**")
for j, r in enumerate(q["rubrics"]):
with st.container(border=True):
meta_parts = [f"**Artifact:** `{r['artifact']}`"]
if r["dimension"]:
meta_parts.append(f"**Dimension:** {r['dimension']}")
st.markdown(" · ".join(meta_parts))
rc1, rc2 = st.columns(2)
with rc1:
r["points"] = st.text_input(
"points",
value=r["points"],
key=f"q_{n}_{i}_r_{j}_points",
)
with rc2:
r["tolerance"] = st.text_input(
"tolerance",
value=r["tolerance"],
key=f"q_{n}_{i}_r_{j}_tolerance",
)
r["criterion"] = st.text_area(
"criterion",
value=r["criterion"],
key=f"q_{n}_{i}_r_{j}_criterion",
height=80,
)
st.button("+ Add question", on_click=_add_question)
st.divider()
# ------------- actions ---------------------------------------------------
action_l, action_r = st.columns([1, 1])
with action_l:
st.button("Save draft", on_click=_save_draft, use_container_width=True)
with action_r:
st.button("Submit", on_click=_submit, type="primary", use_container_width=True)
# ------------- status banner ---------------------------------------------
res = st.session_state.last_result
if res:
if res["kind"] == "success":
st.success(res["msg"])
if res.get("url"):
st.markdown(f"[View on Hugging Face]({res['url']})")
elif res["kind"] == "error":
st.error(res["msg"])
else:
st.info(res["msg"])
with st.expander("Debug: current form state (JSON)"):
st.code(
json.dumps(
{
"trial_id": st.session_state.trial_id,
"username": st.session_state.username,
"prompts": st.session_state.questions,
},
indent=2,
ensure_ascii=False,
),
language="json",
)