File size: 2,757 Bytes
ad1bf38 ac5d59b ad1bf38 ac5d59b c0bdd33 751ed4d ac5d59b 751ed4d ad1bf38 ac5d59b ad1bf38 ac5d59b 0cfd049 751ed4d 0cfd049 751ed4d 0cfd049 751ed4d c0bdd33 ac5d59b c0bdd33 ac5d59b 0cfd049 ac5d59b c0bdd33 ad1bf38 c0bdd33 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
pipeline
)
# -----------------------------
# Load Your Classifier
# -----------------------------
# Hub repository that holds both the fine-tuned sequence-classification weights
# and the tokenizer they were trained with.
CLASSIFIER_REPO = "alemmrr/finbert-gics-sector-classifier"

tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)

# top_k=None requests a score for every label rather than only the best one,
# and device=-1 pins inference to the CPU.
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1,
    top_k=None,
)
# -----------------------------
# Load NER Model (for auto-formatting)
# -----------------------------
# Token-classification model used to extract entities from raw headlines.
# aggregation_strategy="simple" merges sub-word tokens into whole entity spans.
# format_headline_variant3 reads each span's "entity_group" and "word" keys.
NER_REPO = "Jean-Baptiste/roberta-large-ner-english"

ner_pipeline = pipeline(
    task="ner",
    model=NER_REPO,
    aggregation_strategy="simple",
)
# -----------------------------
# Helper: Format headline (Variant 3 Prefixing)
# -----------------------------
def format_headline_variant3(headline, ner=None):
    """Prefix *headline* with its named entities in the "Variant 3" layout.

    Entities are grouped by type in the fixed order ORG, LOC, PER, GPE. Each
    non-empty group is rendered as ``[TAG] word1 | word2``, and the groups are
    joined by single spaces. If at least one entity was kept, the result is
    ``"<groups> [SEP] <headline>"``. Otherwise the headline is returned
    unchanged.

    Args:
        headline: Raw headline text.
        ner: Optional callable that takes the headline and returns a list of
            entity dicts with ``"entity_group"`` and ``"word"`` keys. It
            defaults to the module-level ``ner_pipeline``. Pass another
            callable to swap the NER backend, or to format without loading
            the model.

    Returns:
        The classifier-ready input string.
    """
    if ner is None:
        # Resolved lazily so callers can inject a replacement.
        ner = ner_pipeline
    ents = ner(headline)

    # Buckets in the same order as training Variant-3.
    # Entity groups outside these tags are dropped.
    # NOTE(review): confirm whether the NER model ever emits "GPE". If it does
    # not, this bucket is always empty; it is kept to mirror the training layout.
    entity_buckets = {"ORG": [], "LOC": [], "PER": [], "GPE": []}
    for ent in ents:
        bucket = entity_buckets.get(ent["entity_group"])
        if bucket is not None:
            bucket.append(ent["word"])

    # One "[TAG] a | b" segment per non-empty bucket, space-separated.
    # strip() matches the original trailing-whitespace trimming.
    prefix = " ".join(
        f"[{tag}] " + " | ".join(words)
        for tag, words in entity_buckets.items()
        if words
    ).strip()

    if not prefix:
        return headline
    return prefix + " [SEP] " + headline
# -----------------------------
# Main Prediction Function
# -----------------------------
def predict(text):
    """Classify a plain headline and return the confidence for every sector.

    The headline is first rewritten into the Variant-3 entity-prefixed form
    by ``format_headline_variant3``. It is then scored by the ``clf`` pipeline.

    Args:
        text: Headline as typed by the user. ``None``, empty and
            whitespace-only input short-circuit to an empty result instead of
            classifying a blank string.

    Returns:
        A list of ``{"label": ..., "confidence": ...}`` dicts, sorted from most
        to least confident. Confidence is a percentage rounded to 2 decimals.
    """
    if not text or not text.strip():
        return []

    # Auto-format headline -> Variant 3
    formatted = format_headline_variant3(text)
    outputs = clf(formatted)

    # The pipeline may return the per-label dicts wrapped in one extra list;
    # unwrap it so both output shapes are handled.
    if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
        outputs = outputs[0]

    scores = [
        {"label": o["label"], "confidence": round(float(o["score"]) * 100, 2)}
        for o in outputs
    ]
    # Highest confidence first. The sort is stable, so ties keep the
    # pipeline's order.
    scores.sort(key=lambda s: s["confidence"], reverse=True)
    return scores
# -----------------------------
# Gradio Interface
# -----------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, label="Enter a financial headline (plain text)"),
    outputs=gr.JSON(label="All Sector Scores"),
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    description=(
        "Enter a plain financial news headline. The app automatically applies NER tagging "
        "(ORG, LOC, PER, GPE), prefixes the detected entities to the headline, and returns "
        "a confidence score for every GICS sector."
    ),
)

# Start the web server only when this file is executed as a script. This lets
# the module be imported (e.g. by tests or tooling) without launching the UI.
if __name__ == "__main__":
    demo.launch()
|