import gradio as gr
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
|
|
| |
| |
| |
# Sector classifier: the tokenizer and weights are loaded once at import time
# from the same checkpoint, then wrapped in a text-classification pipeline.
_SECTOR_CHECKPOINT = "alemmrr/finbert-gics-sector-classifier"

tokenizer = AutoTokenizer.from_pretrained(_SECTOR_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_SECTOR_CHECKPOINT)

# top_k=None requests a score for every label instead of only the best one;
# device=-1 keeps inference on the CPU.
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=-1,
)
|
|
| |
| |
| |
# Entity tagger used to build the tagged prefix for each headline.
# aggregation_strategy="simple" makes the pipeline return grouped entity spans
# (read below through their "entity_group" and "word" keys).
_NER_CHECKPOINT = "Jean-Baptiste/roberta-large-ner-english"

ner_pipeline = pipeline(
    task="ner",
    model=_NER_CHECKPOINT,
    aggregation_strategy="simple",
)
|
|
| |
| |
| |
def format_headline_variant3(headline):
    """Prepend the named entities found in *headline* as a tagged prefix.

    The headline is run through ``ner_pipeline`` and the recognised entity
    texts are grouped by tag.  Every non-empty group is rendered as
    ``[TAG] first | second``; groups appear in the fixed order ORG, LOC, PER,
    GPE, are joined by single spaces, and are separated from the headline by
    `` [SEP] ``.  Entities carrying any other tag are ignored, and entity text
    is used exactly as the tagger returns it.

    Args:
        headline: Plain-text headline to annotate.

    Returns:
        ``"<prefix> [SEP] <headline>"`` when at least one tracked entity was
        found, otherwise ``headline`` unchanged.
    """
    # Insertion order of this dict fixes the order of groups in the prefix.
    grouped = {label: [] for label in ("ORG", "LOC", "PER", "GPE")}

    for entity in ner_pipeline(headline):
        label, text = entity["entity_group"], entity["word"]
        bucket = grouped.get(label)
        if bucket is not None:
            bucket.append(text)

    segments = [
        f"[{label}] " + " | ".join(texts)
        for label, texts in grouped.items()
        if texts
    ]
    if not segments:
        return headline

    # strip() trims stray whitespace at the edges of the assembled prefix
    # before the separator is attached.
    return " ".join(segments).strip() + " [SEP] " + headline
|
|
|
|
| |
| |
| |
def predict(text):
    """Score a headline against every sector label.

    The headline is first rewritten by ``format_headline_variant3`` (entity
    prefix + original text) and then passed to the ``clf`` pipeline.

    Args:
        text: Plain-text headline as entered by the user.  ``None``, an empty
            string, or a whitespace-only string is treated as "no input".

    Returns:
        A list of ``{"label": str, "confidence": float}`` dicts, one per
        label, where ``confidence`` is the score as a percentage rounded to
        two decimals.  The list is ordered from most to least likely and is
        empty when there was no input.
    """
    # Nothing to tag or classify: skip both models instead of feeding them a
    # blank string.
    if not text or not text.strip():
        return []

    formatted = format_headline_variant3(text)

    # The entity prefix makes the input longer than the raw headline;
    # truncation=True clips over-length input to the model's maximum sequence
    # length rather than letting the forward pass fail.
    outputs = clf(formatted, truncation=True)

    # The pipeline may wrap the result for a single input in an outer list;
    # unwrap it so we always iterate over the per-label dicts.
    if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
        outputs = outputs[0]

    # Rank on the raw scores so labels that round to the same percentage
    # still come out in their true order.
    ranked = sorted(outputs, key=lambda item: item["score"], reverse=True)

    return [
        {
            "label": item["label"],
            "confidence": round(float(item["score"]) * 100, 2),
        }
        for item in ranked
    ]
|
|
|
|
| |
| |
| |
# Gradio front end: a three-line textbox feeds `predict`, and the returned list
# of per-label scores is rendered as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, label="Enter a financial headline (plain text)"),
    outputs=gr.JSON(label="All Sector Scores"),
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    description=(
        # The original text stopped mid-sentence after "NER tagging "; the
        # second literal completes it.
        "Enter a plain financial news headline. The app automatically applies NER tagging "
        "to the headline before scoring it against every sector label."
    ),
)

# Start the web server as soon as the script is executed.
demo.launch()
|
|