"""Gradio demo for the ChEMU NER model (mpkato/chemu-biobert-ner).
Loads a `BertForTokenClassification` fine-tuned on ChEMU 2020 Task 1 and
shows extracted reaction entities as colored spans via `gr.HighlightedText`.
"""
from __future__ import annotations
import os
import gradio as gr
from transformers import pipeline
# Model identifier handed to ``transformers.pipeline``; the
# CHEMU_MODEL_ID environment variable overrides the default.
MODEL_ID = os.getenv("CHEMU_MODEL_ID", "mpkato/chemu-biobert-ner")

# Sample reaction description pre-filled in the input box.
DEFAULT_TEXT = (
    "Under blue LED light, N-Boc-pyrrolidine was coupled with "
    "4-cyanopyridine in acetonitrile using [Ru(bpy)\u2083]Cl\u2082 "
    "as the photocatalyst and DIPEA as the reductant to afford "
    "tert-butyl 2-(4-pyridyl)pyrrolidine-1-carboxylate."
)

# Rows for gr.Examples: the UI has a single input, so each row is a
# one-element list holding one sample description.
EXAMPLES = [
    [sample]
    for sample in (
        DEFAULT_TEXT,
        "Step 1. 4-chloro-2-fluorobenzoic acid (5.0 g, 12.3 mmol) was "
        "dissolved in dioxane (40 mL) at room temperature for 2 h.",
        "Benzyl bromide (1.5 mmol) and triethylamine (2 mmol) were "
        "stirred in DMF at 60 \u00b0C for 3 hours to give the title "
        "compound (85% yield).",
    )
]
# Highlight colour per ChEMU entity label (10 labels). All entries are
# light pastel tints so that dark text stays legible on highlighted spans.
COLOR_MAP = {
    "STARTING_MATERIAL": "#BBDEFB",  # blue
    "REAGENT_CATALYST":  "#FFE0B2",  # orange
    "REACTION_PRODUCT":  "#C8E6C9",  # green
    "SOLVENT":           "#E1BEE7",  # purple
    "OTHER_COMPOUND":    "#E0E0E0",  # gray
    "TEMPERATURE":       "#FFCDD2",  # red
    "TIME":              "#FFF59D",  # yellow
    "YIELD_PERCENT":     "#B2DFDB",  # teal
    "YIELD_OTHER":       "#B3E5FC",  # cyan
    "EXAMPLE_LABEL":     "#F8BBD0",  # pink
}
# Per-type results on the held-out dev set, taken from the training run.
# Each row is [entity type, F1, support], where support is the number of
# gold entities of that type (the column sums to 3,843).
PER_TYPE_METRICS = [
    # entity type           F1      support
    ["STARTING_MATERIAL",  0.8881,  413],
    ["REAGENT_CATALYST",   0.9005,  289],
    ["REACTION_PRODUCT",   0.9553,  506],
    ["SOLVENT",            0.9545,  250],
    ["OTHER_COMPOUND",     0.9689, 1080],
    ["TEMPERATURE",        0.9813,  346],
    ["TIME",               0.9862,  252],
    ["YIELD_PERCENT",      1.0000,  228],
    ["YIELD_OTHER",        0.9867,  261],
    ["EXAMPLE_LABEL",      0.9862,  218],
]
# Legend content, keyed by entity label:
#     label -> (short description, comma-separated example mentions)
ENTITY_DESCRIPTIONS = {
    "STARTING_MATERIAL": (
        "Reactant providing the core skeleton",
        "aniline, benzyl bromide, N-Boc-pyrrolidine",
    ),
    "REAGENT_CATALYST": (
        "Reagent, catalyst, base, oxidant, reductant",
        "sodium hydride, DIPEA, [Ru(bpy)₃]Cl₂",
    ),
    "REACTION_PRODUCT": (
        "Target product of the reaction",
        "tert-butyl 2-(4-pyridyl)pyrrolidine-1-carboxylate",
    ),
    "SOLVENT": (
        "Reaction or extraction medium",
        "THF, dioxane, acetonitrile",
    ),
    "OTHER_COMPOUND": (
        "Auxiliary: brine, drying agent, wash, by-product",
        "brine, celite, ethyl acetate",
    ),
    "TEMPERATURE": (
        "Reaction temperature (or range)",
        "50 °C, room temperature, −78 °C",
    ),
    "TIME": (
        "Elapsed reaction time",
        "2 h, overnight, 30 min",
    ),
    "YIELD_PERCENT": (
        "Yield as a percentage",
        "56%, 85%, quantitative",
    ),
    "YIELD_OTHER": (
        "Yield expressed as mass or moles",
        "1.30 g, 2.5 mmol",
    ),
    "EXAMPLE_LABEL": (
        "Numeric identifier for a compound or example",
        "Example 5, (1), 14",
    ),
}
def _legend_html() -> str:
rows = []
for label, (desc, examples) in ENTITY_DESCRIPTIONS.items():
color = COLOR_MAP[label]
rows.append(
f"""
"""
)
return '' + "".join(rows) + "
"
# Stylesheet handed to gr.Blocks(css=...). The rules cover:
#   .gradio-container        - caps the page width at 1100px and centres it
#   #header-block, .chip(-row) - the gradient banner at the top of the page
#   .section-title           - headings placed above each UI section
#   .legend-*                - the entity-type legend grid (two columns,
#                              collapsing to one below 700px wide)
#   #footer-block            - the footer panel and its links
# NOTE: this is a runtime string; edit the CSS text, not this comment,
# to change styling.
CUSTOM_CSS = """
.gradio-container {
max-width: 1100px !important;
margin: 0 auto !important;
}
#header-block {
background: linear-gradient(135deg, #1e3c72 0%, #2a5298 100%);
color: #ffffff;
padding: 32px 28px;
border-radius: 18px;
text-align: center;
box-shadow: 0 8px 24px rgba(30, 60, 114, 0.20);
margin-bottom: 24px;
}
#header-block h1 {
color: #ffffff;
margin: 0 0 8px 0;
font-size: 2.2rem;
font-weight: 700;
}
#header-block p {
color: rgba(255, 255, 255, 0.92);
margin: 4px 0;
font-size: 1.0rem;
}
#header-block .chip-row {
margin-top: 14px;
}
#header-block .chip {
display: inline-block;
padding: 6px 14px;
margin: 4px 4px 0 4px;
background: rgba(255, 255, 255, 0.18);
border: 1px solid rgba(255, 255, 255, 0.35);
border-radius: 999px;
font-size: 0.9rem;
font-weight: 500;
}
.section-title {
color: #1e3c72;
font-size: 1.25rem;
font-weight: 700;
margin: 24px 0 8px 0;
}
.legend-grid {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 12px 24px;
padding: 16px 4px;
}
@media (max-width: 700px) {
.legend-grid { grid-template-columns: 1fr; }
}
.legend-row {
display: flex;
align-items: flex-start;
gap: 12px;
}
.legend-chip {
display: inline-block;
min-width: 160px;
padding: 6px 12px;
border-radius: 8px;
font-family: ui-monospace, Menlo, Consolas, monospace;
font-size: 0.82rem;
font-weight: 700;
text-align: center;
color: #1a1a1a;
flex-shrink: 0;
}
.legend-body {
font-size: 0.9rem;
line-height: 1.4;
}
.legend-desc { color: #1a1a1a; }
.legend-examples { color: #5f6368; font-size: 0.82rem; margin-top: 2px; }
#footer-block {
margin-top: 32px;
padding: 18px;
background: #f5f7fb;
border-radius: 12px;
color: #455a64;
text-align: center;
font-size: 0.88rem;
}
#footer-block a { color: #1e3c72; text-decoration: none; font-weight: 600; }
#footer-block a:hover { text-decoration: underline; }
"""
# Banner shown at the top of the page. The markup uses exactly the hooks
# styled in CUSTOM_CSS (#header-block, its h1/p children, .chip-row and
# .chip); previously this string was empty, so the banner never rendered.
HEADER_HTML = """
<div id="header-block">
  <h1>🧪 ChEMU NER Demo</h1>
  <p>Extract reaction entities from chemical reaction descriptions.</p>
  <p>BioBERT token classifier fine-tuned on ChEMU 2020 Task 1.</p>
  <div class="chip-row">
    <span class="chip">10 entity types</span>
    <span class="chip">dev micro-F1 0.9585</span>
    <span class="chip">🤗 Transformers</span>
  </div>
</div>
"""
# Attribution footer, styled by the #footer-block rules in CUSTOM_CSS;
# previously this string was empty, so nothing rendered. The model link is
# derived from MODEL_ID so it follows the CHEMU_MODEL_ID override.
# NOTE(review): the link assumes MODEL_ID is a Hugging Face Hub repo id; it
# will not resolve if MODEL_ID points at a local directory.
FOOTER_HTML = f"""
<div id="footer-block">
  Model: <a href="https://huggingface.co/{MODEL_ID}" target="_blank" rel="noopener">{MODEL_ID}</a>
  &middot; Built with
  <a href="https://www.gradio.app" target="_blank" rel="noopener">Gradio</a>
  and
  <a href="https://huggingface.co/docs/transformers" target="_blank" rel="noopener">🤗 Transformers</a>
</div>
"""
def _load_pipeline():
    """Construct the Hugging Face token-classification pipeline for MODEL_ID.

    ``aggregation_strategy="simple"`` groups sub-word predictions into
    entity spans that carry ``entity_group``, ``start`` and ``end`` fields.
    ``stride=64`` makes the pipeline cover the whole input: text longer
    than the model's maximum length is split into chunks that overlap by
    64 tokens instead of being cut off. Per the transformers docs, this
    requires a fast tokenizer and a non-"none" aggregation strategy.
    """
    settings = dict(
        model=MODEL_ID,
        aggregation_strategy="simple",
        stride=64,
    )
    return pipeline("token-classification", **settings)


# Built once when the module is imported; extract() reuses this instance.
NER = _load_pipeline()
def extract(text: str):
    """Tag *text* with the NER pipeline and split it into labelled pieces.

    The result is the list-of-pairs format accepted by
    ``gr.HighlightedText``. Each entity span is paired with its
    ``entity_group`` label, and any text before, between or after entities
    is paired with ``None``. Empty input (``""`` or ``None``) returns ``[]``
    without invoking the model.
    """
    if not text:
        return []

    pieces: list[tuple[str, str | None]] = []
    last = 0
    for ent in NER(text):
        lo, hi = int(ent["start"]), int(ent["end"])
        # Emit the unlabelled stretch since the previous entity, if any.
        if lo > last:
            pieces.append((text[last:lo], None))
        pieces.append((text[lo:hi], ent["entity_group"]))
        last = hi

    tail = text[last:]
    if tail:
        pieces.append((tail, None))
    return pieces
def _section_title(text: str) -> str:
    """Wrap *text* in the ``.section-title`` div styled by CUSTOM_CSS."""
    return f'<div class="section-title">{text}</div>'


with gr.Blocks(
    title="ChEMU NER Demo",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="blue",
    ),
    css=CUSTOM_CSS,
) as demo:
    gr.HTML(HEADER_HTML)

    # --- Input ------------------------------------------------------------
    gr.HTML(_section_title("🧪 Reaction description"))
    text_in = gr.Textbox(
        label="",
        lines=6,
        value=DEFAULT_TEXT,
        placeholder="Paste a chemical reaction description here...",
        show_label=False,
    )
    with gr.Row():
        extract_btn = gr.Button(
            "⚡ Extract entities",
            variant="primary",
            size="lg",
            scale=3,
        )
        clear_btn = gr.Button("Clear", variant="secondary", scale=1)

    # --- Output -----------------------------------------------------------
    gr.HTML(_section_title("🔍 Extracted entities"))
    highlighted = gr.HighlightedText(
        label="",
        combine_adjacent=True,
        # The built-in legend is hidden; the custom HTML legend below
        # documents every label instead.
        show_legend=False,
        color_map=COLOR_MAP,
        show_label=False,
    )

    gr.HTML(_section_title("📋 Quick examples"))
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_in],
        label="",
        examples_per_page=3,
    )

    with gr.Accordion("📊 Held-out dev performance (ChEMU 2020)", open=False):
        gr.Dataframe(
            headers=["Entity type", "F1", "support"],
            value=[[t, f"{f1:.4f}", n] for t, f1, n in PER_TYPE_METRICS],
            interactive=False,
            wrap=True,
        )
        gr.Markdown(
            "**Overall micro-F1 = 0.9585** on 225 held-out dev documents "
            "(3,843 entities). For reference, the official BANNER baseline "
            "scores 0.8893."
        )

    gr.HTML(_section_title("📖 Entity type legend"))
    gr.HTML(_legend_html())
    gr.HTML(FOOTER_HTML)

    # --- Events -----------------------------------------------------------
    extract_btn.click(extract, inputs=[text_in], outputs=[highlighted])
    clear_btn.click(lambda: ("", []), outputs=[text_in, highlighted])
    text_in.submit(extract, inputs=[text_in], outputs=[highlighted])
    # demo.load fires on every page load, not just once when the app
    # starts. Each visitor therefore sees the default text already
    # highlighted, at the cost of one model call per page view.
    demo.load(extract, inputs=[text_in], outputs=[highlighted])

if __name__ == "__main__":
    demo.launch()