"""Gradio demo for the ChEMU NER model (mpkato/chemu-biobert-ner). Loads a `BertForTokenClassification` fine-tuned on ChEMU 2020 Task 1 and shows extracted reaction entities as colored spans via `gr.HighlightedText`. """ from __future__ import annotations import os import gradio as gr from transformers import pipeline MODEL_ID = os.environ.get("CHEMU_MODEL_ID", "mpkato/chemu-biobert-ner") TITLE = "ChEMU NER (BioBERT)" DESCRIPTION = """\ Fine-tuned **BioBERT** for extracting reaction-step entities from chemical patents, trained on the [ChEMU 2020 Task 1]\ (https://chemu-patent-ie.github.io/) corpus. Paste any chemical patent snippet below and the model will highlight the 10 entity types (reactants, catalysts, solvents, products, conditions, yields, labels). **Held-out dev F1 (exact match, micro): \u2248 0.95** """ ENTITY_GUIDE = """\ | Label | Meaning | Examples | |---|---|---| | **STARTING_MATERIAL** | reactant that provides the core skeleton | `aniline`, `benzyl bromide` | | **REAGENT_CATALYST** | reagent / catalyst / base / oxidant / reductant | `sodium hydride`, `DIPEA` | | **REACTION_PRODUCT** | target product of the reaction | `tert-butyl 2-(4-pyridyl)pyrrolidine-1-carboxylate` | | **SOLVENT** | reaction or extraction medium | `THF`, `dioxane`, `acetonitrile` | | **OTHER_COMPOUND** | auxiliary: brines, drying agents, washes, by-products | `brine`, `celite`, `ethyl acetate` | | **TEMPERATURE** | reaction temperature or range | `50 \u00b0C`, `room temperature` | | **TIME** | elapsed reaction time | `2 h`, `overnight`, `30 min` | | **YIELD_PERCENT** | yield expressed as percentage | `56%`, `quantitative` | | **YIELD_OTHER** | yield expressed as mass or moles | `1.30 g`, `2.5 mmol` | | **EXAMPLE_LABEL** | numeric/identifier labels for compounds or examples | `Example 5`, `(1)`, `14` | """ EXAMPLES = [ [ "Under blue LED light, N-Boc-pyrrolidine was coupled with " "4-cyanopyridine in acetonitrile using [Ru(bpy)\u2083]Cl\u2082 " "as the photocatalyst and DIPEA as the reductant to afford " "tert-butyl 2-(4-pyridyl)pyrrolidine-1-carboxylate." ], [ "Step 1. 4-chloro-2-fluorobenzoic acid (5.0 g, 12.3 mmol) was " "dissolved in dioxane (40 mL) at room temperature for 2 h." ], [ "Benzyl bromide (1.5 mmol) and triethylamine (2 mmol) were " "stirred in DMF at 60 \u00b0C for 3 hours to give the title " "compound (85% yield)." ], ] def _load_pipeline(): return pipeline( "token-classification", model=MODEL_ID, aggregation_strategy="simple", ) NER = _load_pipeline() def extract(text: str): """Run the NER model and return a list of (text, label) segments. Gradio's `HighlightedText` accepts a list of tuples where `label=None` means un-highlighted plain text. """ if not text: return [] result = NER(text) # `aggregation_strategy="simple"` merges adjacent subwords into entity # chunks with `start`, `end`, `entity_group` fields. We walk the text # and emit plain / highlighted segments in order. spans: list[tuple[str, str | None]] = [] cursor = 0 for ent in result: start, end = int(ent["start"]), int(ent["end"]) if start > cursor: spans.append((text[cursor:start], None)) spans.append((text[start:end], ent["entity_group"])) cursor = end if cursor < len(text): spans.append((text[cursor:], None)) return spans with gr.Blocks(title=TITLE, theme=gr.themes.Soft()) as demo: gr.Markdown(f"# {TITLE}") gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=2): text_in = gr.Textbox( label="Chemical patent text", lines=8, placeholder="Paste a reaction description here...", ) extract_btn = gr.Button("Extract entities", variant="primary") highlighted = gr.HighlightedText( label="Detected entities", combine_adjacent=True, show_legend=True, ) gr.Examples(examples=EXAMPLES, inputs=[text_in]) with gr.Column(scale=1): gr.Markdown("### Entity legend") gr.Markdown(ENTITY_GUIDE) extract_btn.click(extract, inputs=[text_in], outputs=[highlighted]) text_in.submit(extract, inputs=[text_in], outputs=[highlighted]) gr.Markdown( "---\n" "Model: [`mpkato/chemu-biobert-ner`](https://huggingface.co/mpkato/chemu-biobert-ner) \n" "Training data: ChEMU 2020 NER corpus (CC BY-NC 3.0), for **non-commercial research use only**. \n" "Base encoder: [`dmis-lab/biobert-base-cased-v1.2`](https://huggingface.co/dmis-lab/biobert-base-cased-v1.2)" ) if __name__ == "__main__": demo.launch()