File size: 2,757 Bytes
ad1bf38
ac5d59b
 
 
 
 
 
ad1bf38
ac5d59b
 
 
c0bdd33
 
 
751ed4d
 
 
 
ac5d59b
 
 
 
 
 
 
 
 
 
 
751ed4d
ad1bf38
ac5d59b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad1bf38
ac5d59b
 
 
 
0cfd049
 
 
 
 
751ed4d
 
 
 
0cfd049
751ed4d
 
0cfd049
 
751ed4d
 
c0bdd33
ac5d59b
 
 
 
c0bdd33
 
ac5d59b
0cfd049
ac5d59b
 
 
 
c0bdd33
ad1bf38
c0bdd33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    pipeline
)

# -----------------------------
# Load Your Classifier
# -----------------------------
# Hugging Face Hub id of the fine-tuned FinBERT GICS sector classifier; both the
# tokenizer and the model weights are loaded from it.
_CLASSIFIER_REPO = "alemmrr/finbert-gics-sector-classifier"

tokenizer = AutoTokenizer.from_pretrained(_CLASSIFIER_REPO)
model = AutoModelForSequenceClassification.from_pretrained(_CLASSIFIER_REPO)

# top_k=None asks the pipeline for a score for every label instead of only the
# best one; device=-1 pins inference to the CPU.
clf = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=-1)

# -----------------------------
# Load NER Model (for auto-formatting)
# -----------------------------
# Token-classification model whose grouped entities are consumed by
# format_headline_variant3. aggregation_strategy="simple" merges sub-word tokens
# into whole entities, each reported with an "entity_group" label and a "word".
_NER_MODEL_ID = "Jean-Baptiste/roberta-large-ner-english"
ner_pipeline = pipeline(
    task="ner",
    model=_NER_MODEL_ID,
    aggregation_strategy="simple",
)

# -----------------------------
# Helper: Format headline (Variant 3 Prefixing)
# -----------------------------
def format_headline_variant3(headline, ner=None):
    """Prefix *headline* with its named entities in the "Variant 3" format.

    Runs NER over the headline, groups the recognised entity strings by tag and
    prepends them, e.g.::

        "[ORG] Apple | Microsoft [LOC] Cupertino [SEP] <headline>"

    Args:
        headline: Plain-text headline to format.
        ner: Optional callable that takes the headline and returns a list of
            dicts with ``"entity_group"`` and ``"word"`` keys (the output shape
            of a Hugging Face ``"ner"`` pipeline with
            ``aggregation_strategy="simple"``). Defaults to the module-level
            ``ner_pipeline``; exposed so the formatting can be exercised without
            loading a model.

    Returns:
        ``"<prefix> [SEP] <headline>"``, or *headline* unchanged when no entity
        falls into one of the known buckets.
    """
    if ner is None:
        ner = ner_pipeline
    ents = ner(headline)

    # Buckets (same as training Variant-3). Dict order fixes the order of the
    # tag groups in the prefix.
    # NOTE(review): the roberta NER model presumably emits PER/ORG/LOC/MISC, so
    # "GPE" may never be filled at inference time — confirm against how the
    # training data was tagged.
    entity_buckets = {"ORG": [], "LOC": [], "PER": [], "GPE": []}

    for ent in ents:
        values = entity_buckets.get(ent["entity_group"])
        if values is None:
            continue  # tags outside the training buckets (e.g. MISC) are dropped
        # Byte-level BPE tokenizers (RoBERTa) can decode mid-sentence entities
        # with a leading space (" Microsoft"); strip it so the prefix uses
        # single spacing, and skip entities that are pure whitespace.
        word = ent["word"].strip()
        if word:
            values.append(word)

    groups = [
        f"[{tag}] " + " | ".join(values)
        for tag, values in entity_buckets.items()
        if values
    ]
    if not groups:
        return headline

    # Return final formatted input for classifier
    return " ".join(groups) + " [SEP] " + headline


# -----------------------------
# Main Prediction Function
# -----------------------------
def predict(text):
    """Score a plain-text headline against every GICS sector.

    The headline is first auto-formatted with ``format_headline_variant3`` and
    then scored by the sector classifier ``clf``.

    Args:
        text: Plain-text financial headline.

    Returns:
        A list of ``{"label": ..., "confidence": ...}`` dicts, where confidence
        is the classifier score as a percentage rounded to two decimals, ordered
        from most to least confident.
    """
    raw = clf(format_headline_variant3(text))

    # The pipeline may hand back a single input's scores wrapped in one extra
    # list (list-of-lists); unwrap exactly that one level.
    wrapped_once = isinstance(raw, list) and len(raw) == 1 and isinstance(raw[0], list)
    if wrapped_once:
        raw = raw[0]

    ranked = []
    for item in raw:
        pct = round(float(item["score"]) * 100, 2)
        ranked.append({"label": item["label"], "confidence": pct})

    # Stable in-place sort, highest confidence first.
    ranked.sort(key=lambda entry: entry["confidence"], reverse=True)
    return ranked


# -----------------------------
# Gradio Interface
# -----------------------------
# Single-textbox UI: the headline goes through predict() and the full list of
# sector scores is rendered as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, label="Enter a financial headline (plain text)"),
    outputs=gr.JSON(label="All Sector Scores"),
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    # The original description stopped mid-sentence; completed here.
    description=(
        "Enter a plain financial news headline. The app automatically applies NER tagging "
        "to prefix the detected organizations, locations and people, then returns "
        "confidence scores (%) for every GICS sector."
    ),
)

demo.launch()