"""Gradio demo for the ChEMU NER model (mpkato/chemu-biobert-ner). Loads a `BertForTokenClassification` fine-tuned on ChEMU 2020 Task 1 and shows extracted reaction entities as colored spans via `gr.HighlightedText`. """ from __future__ import annotations import os import gradio as gr from transformers import pipeline MODEL_ID = os.environ.get("CHEMU_MODEL_ID", "mpkato/chemu-biobert-ner") DEFAULT_TEXT = ( "Under blue LED light, N-Boc-pyrrolidine was coupled with " "4-cyanopyridine in acetonitrile using [Ru(bpy)\u2083]Cl\u2082 " "as the photocatalyst and DIPEA as the reductant to afford " "tert-butyl 2-(4-pyridyl)pyrrolidine-1-carboxylate." ) EXAMPLES = [ [DEFAULT_TEXT], [ "Step 1. 4-chloro-2-fluorobenzoic acid (5.0 g, 12.3 mmol) was " "dissolved in dioxane (40 mL) at room temperature for 2 h." ], [ "Benzyl bromide (1.5 mmol) and triethylamine (2 mmol) were " "stirred in DMF at 60 \u00b0C for 3 hours to give the title " "compound (85% yield)." ], ] # Color palette for the 10 entity types. Colors are chosen to be # visually distinct and mutually readable on a light background. 
COLOR_MAP = {
    "STARTING_MATERIAL": "#BBDEFB",  # blue
    "REAGENT_CATALYST": "#FFE0B2",  # orange
    "REACTION_PRODUCT": "#C8E6C9",  # green
    "SOLVENT": "#E1BEE7",  # purple
    "OTHER_COMPOUND": "#E0E0E0",  # gray
    "TEMPERATURE": "#FFCDD2",  # red
    "TIME": "#FFF59D",  # yellow
    "YIELD_PERCENT": "#B2DFDB",  # teal
    "YIELD_OTHER": "#B3E5FC",  # cyan
    "EXAMPLE_LABEL": "#F8BBD0",  # pink
}

# Held-out dev F1 per type (from the training run): [type, F1, support].
PER_TYPE_METRICS = [
    ["STARTING_MATERIAL", 0.8881, 413],
    ["REAGENT_CATALYST", 0.9005, 289],
    ["REACTION_PRODUCT", 0.9553, 506],
    ["SOLVENT", 0.9545, 250],
    ["OTHER_COMPOUND", 0.9689, 1080],
    ["TEMPERATURE", 0.9813, 346],
    ["TIME", 0.9862, 252],
    ["YIELD_PERCENT", 1.0000, 228],
    ["YIELD_OTHER", 0.9867, 261],
    ["EXAMPLE_LABEL", 0.9862, 218],
]

# Total gold entities across all types (3,843 for the table above); derived
# so the summary text below cannot drift from the per-type table.
TOTAL_SUPPORT = sum(n for _, _, n in PER_TYPE_METRICS)

# label -> (short description, comma-separated examples) for the legend.
ENTITY_DESCRIPTIONS = {
    "STARTING_MATERIAL": ("Reactant providing the core skeleton",
                          "aniline, benzyl bromide, N-Boc-pyrrolidine"),
    "REAGENT_CATALYST": ("Reagent, catalyst, base, oxidant, reductant",
                         "sodium hydride, DIPEA, [Ru(bpy)₃]Cl₂"),
    "REACTION_PRODUCT": ("Target product of the reaction",
                         "tert-butyl 2-(4-pyridyl)pyrrolidine-1-carboxylate"),
    "SOLVENT": ("Reaction or extraction medium",
                "THF, dioxane, acetonitrile"),
    "OTHER_COMPOUND": ("Auxiliary: brine, drying agent, wash, by-product",
                       "brine, celite, ethyl acetate"),
    "TEMPERATURE": ("Reaction temperature (or range)",
                    "50 °C, room temperature, −78 °C"),
    "TIME": ("Elapsed reaction time",
             "2 h, overnight, 30 min"),
    "YIELD_PERCENT": ("Yield as a percentage",
                      "56%, 85%, quantitative"),
    "YIELD_OTHER": ("Yield expressed as mass or moles",
                    "1.30 g, 2.5 mmol"),
    "EXAMPLE_LABEL": ("Numeric identifier for a compound or example",
                      "Example 5, (1), 14"),
}


def _legend_html() -> str:
    """Return the entity-type legend as an HTML grid.

    Emits one row per ENTITY_DESCRIPTIONS entry: a chip filled with the
    label's COLOR_MAP color, then the description and examples. Class names
    match the `.legend-*` rules in CUSTOM_CSS. All inputs are module
    constants, so no HTML escaping is applied.
    """
    rows = []
    for label, (desc, examples) in ENTITY_DESCRIPTIONS.items():
        color = COLOR_MAP[label]
        rows.append(
            f"""
<div class="legend-row">
  <span class="legend-chip" style="background: {color};">{label}</span>
  <div class="legend-body">
    <div class="legend-desc">{desc}</div>
    <div class="legend-examples">e.g. {examples}</div>
  </div>
</div>
"""
        )
    return '<div class="legend-grid">' + "".join(rows) + "</div>"


def _section_title(title: str) -> str:
    """Wrap *title* in the `.section-title` div styled by CUSTOM_CSS."""
    return f'<div class="section-title">{title}</div>'


CUSTOM_CSS = """
.gradio-container {
    max-width: 1100px !important;
    margin: 0 auto !important;
}
#header-block {
    background: linear-gradient(135deg, #1e3c72 0%, #2a5298 100%);
    color: #ffffff;
    padding: 32px 28px;
    border-radius: 18px;
    text-align: center;
    box-shadow: 0 8px 24px rgba(30, 60, 114, 0.20);
    margin-bottom: 24px;
}
#header-block h1 {
    color: #ffffff;
    margin: 0 0 8px 0;
    font-size: 2.2rem;
    font-weight: 700;
}
#header-block p {
    color: rgba(255, 255, 255, 0.92);
    margin: 4px 0;
    font-size: 1.0rem;
}
#header-block .chip-row {
    margin-top: 14px;
}
#header-block .chip {
    display: inline-block;
    padding: 6px 14px;
    margin: 4px 4px 0 4px;
    background: rgba(255, 255, 255, 0.18);
    border: 1px solid rgba(255, 255, 255, 0.35);
    border-radius: 999px;
    font-size: 0.9rem;
    font-weight: 500;
}
.section-title {
    color: #1e3c72;
    font-size: 1.25rem;
    font-weight: 700;
    margin: 24px 0 8px 0;
}
.legend-grid {
    display: grid;
    grid-template-columns: repeat(2, 1fr);
    gap: 12px 24px;
    padding: 16px 4px;
}
@media (max-width: 700px) {
    .legend-grid {
        grid-template-columns: 1fr;
    }
}
.legend-row {
    display: flex;
    align-items: flex-start;
    gap: 12px;
}
.legend-chip {
    display: inline-block;
    min-width: 160px;
    padding: 6px 12px;
    border-radius: 8px;
    font-family: ui-monospace, Menlo, Consolas, monospace;
    font-size: 0.82rem;
    font-weight: 700;
    text-align: center;
    color: #1a1a1a;
    flex-shrink: 0;
}
.legend-body {
    font-size: 0.9rem;
    line-height: 1.4;
}
.legend-desc {
    color: #1a1a1a;
}
.legend-examples {
    color: #5f6368;
    font-size: 0.82rem;
    margin-top: 2px;
}
#footer-block {
    margin-top: 32px;
    padding: 18px;
    background: #f5f7fb;
    border-radius: 12px;
    color: #455a64;
    text-align: center;
    font-size: 0.88rem;
}
#footer-block a {
    color: #1e3c72;
    text-decoration: none;
    font-weight: 600;
}
#footer-block a:hover {
    text-decoration: underline;
}
"""

HEADER_HTML = """
<div id="header-block">
  <h1>⚗️ ChEMU NER Demo</h1>
  <p>Named-entity extraction for chemical reaction descriptions in patents</p>
  <div class="chip-row">
    <span class="chip">BioBERT fine-tune</span>
    <span class="chip">held-out dev F1 = 0.9585</span>
    <span class="chip">10 entity types</span>
    <span class="chip">CC BY-NC 3.0</span>
  </div>
</div>
"""

# NOTE(review): the footer markup appears empty here although CUSTOM_CSS
# styles `#footer-block` and its links -- confirm the intended content.
FOOTER_HTML = """ """


def _load_pipeline():
    """Build the token-classification pipeline for MODEL_ID.

    `aggregation_strategy="simple"` groups sub-word predictions into entity
    spans carrying `entity_group`, `start` and `end`; `stride=64` makes
    inputs longer than the model's maximum length be processed in
    overlapping chunks instead of being truncated.
    """
    return pipeline(
        "token-classification",
        model=MODEL_ID,
        aggregation_strategy="simple",
        stride=64,
    )


NER = _load_pipeline()


def extract(text: str):
    """Run the NER model and return a list of (text, label) segments.

    The segments concatenate back to *text*. Entity spans carry their
    `entity_group` label and the gaps between them carry ``None``; this is
    the value format `gr.HighlightedText` accepts. Empty or whitespace-only
    input (including ``None``) returns ``[]`` without calling the model.
    """
    if not text or not text.strip():
        return []
    spans: list[tuple[str, str | None]] = []
    cursor = 0
    # Process entities left to right and clip any overlap with the previous
    # one, so no character is emitted twice in the highlighted output.
    for ent in sorted(NER(text), key=lambda e: int(e["start"])):
        start = max(int(ent["start"]), cursor)
        end = int(ent["end"])
        if end <= start:
            continue
        if start > cursor:
            spans.append((text[cursor:start], None))
        spans.append((text[start:end], ent["entity_group"]))
        cursor = end
    if cursor < len(text):
        spans.append((text[cursor:], None))
    return spans


with gr.Blocks(
    title="ChEMU NER Demo",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="blue",
    ),
    css=CUSTOM_CSS,
) as demo:
    gr.HTML(HEADER_HTML)

    gr.HTML(_section_title("🧪 Reaction description"))
    text_in = gr.Textbox(
        label="",
        lines=6,
        value=DEFAULT_TEXT,
        placeholder="Paste a chemical reaction description here...",
        show_label=False,
    )
    with gr.Row():
        extract_btn = gr.Button(
            "⚡ Extract entities",
            variant="primary",
            size="lg",
            scale=3,
        )
        clear_btn = gr.Button("Clear", variant="secondary", scale=1)

    gr.HTML(_section_title("🔍 Extracted entities"))
    highlighted = gr.HighlightedText(
        label="",
        combine_adjacent=True,
        show_legend=False,
        color_map=COLOR_MAP,
        show_label=False,
    )

    gr.HTML(_section_title("📋 Quick examples"))
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_in],
        label="",
        examples_per_page=3,
    )

    with gr.Accordion("📊 Held-out dev performance (ChEMU 2020)", open=False):
        gr.Dataframe(
            headers=["Entity type", "F1", "support"],
            value=[[t, f"{f1:.4f}", n] for t, f1, n in PER_TYPE_METRICS],
            interactive=False,
            wrap=True,
        )
        # NOTE(review): the baseline figure may come from a different split
        # than this dev-set score -- confirm before treating it as a direct
        # comparison.
        gr.Markdown(
            "**Overall micro-F1 = 0.9585** on 225 held-out dev documents "
            f"({TOTAL_SUPPORT:,} entities). For reference, the official BANNER baseline "
            "scores 0.8893."
        )

    gr.HTML(_section_title("📖 Entity type legend"))
    gr.HTML(_legend_html())
    gr.HTML(FOOTER_HTML)

    extract_btn.click(extract, inputs=[text_in], outputs=[highlighted])
    # Reset both the input box and the highlighted output.
    clear_btn.click(lambda: ("", []), outputs=[text_in, highlighted])
    text_in.submit(extract, inputs=[text_in], outputs=[highlighted])
    # Run inference once at load so the user sees a highlighted result
    # as soon as the Space boots.
    demo.load(extract, inputs=[text_in], outputs=[highlighted])

if __name__ == "__main__":
    demo.launch()