File size: 11,519 Bytes
4eb91d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""Calcul des métriques de précision sur entités nommées (NER).

Sprint 38 — A.II.1.a du plan d'évolution 2026 : couche de calcul pure.

Pourquoi ce module
------------------
Pour un médiéviste, un archiviste ou un économiste-historien,
l'utilité aval d'un OCR ne se mesure pas seulement au CER ; ce qui
compte c'est de savoir si les **entités nommées** (personnes, lieux,
dates, organisations) ont survécu à la transcription.  Un CER de 5 %
qui rate 80 % des noms propres est inutilisable pour l'indexation
prosopographique.

Stratégie de découpage en sprints
---------------------------------
Comme pour la divergence taxonomique (Sprints 35-37), on découpe :

- **Sprint 38** (ici) — couche de calcul pure : alignement IoU entre
  deux listes d'entités, calcul de Precision/Recall/F1 par catégorie
  et global, détection des hallucinations d'entité.  Aucune dépendance
  externe (pas de spaCy, pas de Stanza) ; les listes d'entités sont
  fournies en entrée.  Un test de l'enregistrement dans le registre
  typé Sprint 34 garantit l'intégration.
- **Sprint à venir** — backend extracteur (spaCy / Stanza / HIPE) et
  câblage runner+narratif+HTML.

Format des entités
------------------
Compatible avec ``EntitiesGT`` du Sprint 32 — chaque entité est un
dictionnaire ``{"label": str, "start": int, "end": int, "text": str}``
où ``start``/``end`` sont des offsets caractère.

Convention d'alignement
-----------------------
Une entité hypothèse "matche" une entité de référence si :

1. les **labels sont identiques** (case-insensitive),
2. le ratio d'**Intersection-over-Union** (IoU) sur leurs spans
   caractère est ``≥ iou_threshold`` (défaut : 0,5).

Une entité de référence non matchée → faux négatif (recall pénalisé).
Une entité hypothèse non matchée → faux positif (précision pénalisée).
Un faux positif est aussi compté comme **hallucination d'entité**, ce
qui est utile pour les VLM/LLM qui inventent.

Limites
-------
- L'alignement bag-of-spans : une entité peut être matchée par au plus
  une entité de l'autre côté (sinon double-comptage).
- Les modèles NER (spaCy, etc.) hallucinent eux-mêmes.  La métrique
  mesure conjointement OCR + NER.  Documenter explicitement.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Iterable

from picarones.evaluation.metric_registry import register_metric
from picarones.domain.artifacts import ArtifactType

logger = logging.getLogger(__name__)


# ──────────────────────────────────────────────────────────────────────────
# Modèle de données
# ──────────────────────────────────────────────────────────────────────────


@dataclass(frozen=True)
class Entity:
    """Entité nommée alignée sur un texte.

    Attributs
    ---------
    label:
        Catégorie de l'entité (ex. ``"PER"``, ``"LOC"``, ``"DATE"``).
        La comparaison se fait en *case-insensitive*.
    start, end:
        Offsets caractère (inclus, exclu) sur le texte de référence.
    text:
        Forme de surface — informative, **non utilisée pour
        l'alignement** (deux entités peuvent matcher même si leur
        forme de surface diffère, du moment que leurs spans
        chevauchent suffisamment).
    """

    label: str
    start: int
    end: int
    text: str = ""

    def __post_init__(self) -> None:
        if self.start > self.end:
            raise ValueError(
                f"Entity span invalide : start={self.start} > end={self.end}"
            )

    @property
    def length(self) -> int:
        return max(0, self.end - self.start)


def _to_entity(obj: Entity | dict) -> Entity:
    """Coerce un dict (format EntitiesGT) en ``Entity``."""
    if isinstance(obj, Entity):
        return obj
    return Entity(
        label=str(obj["label"]),
        start=int(obj["start"]),
        end=int(obj["end"]),
        text=str(obj.get("text", "")),
    )


# ──────────────────────────────────────────────────────────────────────────
# Alignement par IoU
# ──────────────────────────────────────────────────────────────────────────


def _iou(a: Entity, b: Entity) -> float:
    """Intersection-over-Union sur les spans caractère."""
    inter_start = max(a.start, b.start)
    inter_end = min(a.end, b.end)
    inter = max(0, inter_end - inter_start)
    union = a.length + b.length - inter
    if union <= 0:
        return 0.0
    return inter / union


def _align(
    references: list[Entity],
    hypotheses: list[Entity],
    iou_threshold: float,
) -> tuple[list[tuple[int, int, float]], set[int], set[int]]:
    """Aligne deux listes d'entités par IoU décroissant (greedy).

    Returns
    -------
    matches:
        Liste de triplets ``(idx_ref, idx_hyp, iou)`` triés par IoU
        décroissant — chaque entité n'apparaît qu'une fois.
    unmatched_refs:
        Indices des entités GT non matchées (faux négatifs).
    unmatched_hyps:
        Indices des entités hypothèse non matchées (faux positifs).
    """
    candidates: list[tuple[float, int, int]] = []
    for i, r in enumerate(references):
        for j, h in enumerate(hypotheses):
            if r.label.casefold() != h.label.casefold():
                continue
            score = _iou(r, h)
            if score >= iou_threshold:
                candidates.append((score, i, j))

    # Tri par IoU décroissant ; à IoU égale, on prend l'ordre des paires
    # pour garantir un tri stable et déterministe.
    candidates.sort(key=lambda t: (-t[0], t[1], t[2]))

    matched_refs: set[int] = set()
    matched_hyps: set[int] = set()
    matches: list[tuple[int, int, float]] = []
    for score, i, j in candidates:
        if i in matched_refs or j in matched_hyps:
            continue
        matched_refs.add(i)
        matched_hyps.add(j)
        matches.append((i, j, score))

    unmatched_refs = set(range(len(references))) - matched_refs
    unmatched_hyps = set(range(len(hypotheses))) - matched_hyps
    return matches, unmatched_refs, unmatched_hyps


# ──────────────────────────────────────────────────────────────────────────
# Calcul des métriques
# ──────────────────────────────────────────────────────────────────────────


def _prf(tp: int, fp: int, fn: int) -> dict[str, float]:
    """Précision / rappel / F1 à partir des comptes."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "support": tp + fn,
    }


def compute_ner_metrics(
    reference_entities: Iterable[Entity | dict],
    hypothesis_entities: Iterable[Entity | dict],
    iou_threshold: float = 0.5,
) -> dict:
    """Calcule la précision/rappel/F1 sur entités nommées.

    Parameters
    ----------
    reference_entities:
        Liste d'entités GT (format ``Entity`` ou dict de
        ``EntitiesGT``).
    hypothesis_entities:
        Liste d'entités produites par le NER sur la sortie OCR.
    iou_threshold:
        Seuil de chevauchement caractère pour qu'un appariement
        soit valide (défaut : 0,5 — convention CoNLL/HIPE).

    Returns
    -------
    dict
        ``{
            "global": {"precision", "recall", "f1", "support"},
            "per_category": {label: {"precision", ...}},
            "true_positives": int,
            "false_positives": int,
            "false_negatives": int,
            "hallucinated_entities": list[dict],   # entités OCR sans GT
            "missed_entities":       list[dict],   # entités GT non détectées
            "iou_threshold": float,
        }``
    """
    refs = [_to_entity(e) for e in reference_entities]
    hyps = [_to_entity(e) for e in hypothesis_entities]

    matches, unmatched_refs, unmatched_hyps = _align(refs, hyps, iou_threshold)

    tp = len(matches)
    fn = len(unmatched_refs)
    fp = len(unmatched_hyps)

    # Comptes par catégorie
    cat_tp: dict[str, int] = {}
    cat_fn: dict[str, int] = {}
    cat_fp: dict[str, int] = {}
    for i, _j, _score in matches:
        cat = refs[i].label
        cat_tp[cat] = cat_tp.get(cat, 0) + 1
    for i in unmatched_refs:
        cat = refs[i].label
        cat_fn[cat] = cat_fn.get(cat, 0) + 1
    for j in unmatched_hyps:
        cat = hyps[j].label
        cat_fp[cat] = cat_fp.get(cat, 0) + 1

    all_categories = sorted(set(cat_tp) | set(cat_fn) | set(cat_fp))
    per_category = {
        cat: _prf(cat_tp.get(cat, 0), cat_fp.get(cat, 0), cat_fn.get(cat, 0))
        for cat in all_categories
    }

    return {
        "global": _prf(tp, fp, fn),
        "per_category": per_category,
        "true_positives": tp,
        "false_positives": fp,
        "false_negatives": fn,
        "hallucinated_entities": [
            {"label": hyps[j].label, "start": hyps[j].start,
             "end": hyps[j].end, "text": hyps[j].text}
            for j in sorted(unmatched_hyps)
        ],
        "missed_entities": [
            {"label": refs[i].label, "start": refs[i].start,
             "end": refs[i].end, "text": refs[i].text}
            for i in sorted(unmatched_refs)
        ],
        "iou_threshold": iou_threshold,
    }


# ──────────────────────────────────────────────────────────────────────────
# Enregistrement dans le registre typé (Sprint 34)
# ──────────────────────────────────────────────────────────────────────────


@register_metric(
    name="ner_f1",
    input_types=(ArtifactType.ENTITIES, ArtifactType.ENTITIES),
    description=(
        "F1 global sur les entités nommées (alignement IoU ≥ 0,5, "
        "labels case-insensitive). Pour le détail par catégorie, "
        "utiliser compute_ner_metrics directement."
    ),
    higher_is_better=True,
    tags={"downstream", "ner", "structure"},
)
def ner_f1(
    reference_entities: Iterable[Entity | dict],
    hypothesis_entities: Iterable[Entity | dict],
) -> float:
    """F1 global ; raccourci enregistré pour les jonctions ``(ENTITIES, ENTITIES)``."""
    return compute_ner_metrics(reference_entities, hypothesis_entities)["global"]["f1"]


__all__ = [
    "Entity",
    "compute_ner_metrics",
    "ner_f1",
]