Spaces:
Sleeping
Sleeping
Claude
feat(sprint-E.2): 10 modules measurements/ migrés vers evaluation/metrics/
4eb91d0 unverified | """Calcul des métriques de précision sur entités nommées (NER). | |
| Sprint 38 — A.II.1.a du plan d'évolution 2026 : couche de calcul pure. | |
| Pourquoi ce module | |
| ------------------ | |
| Pour un médiéviste, un archiviste ou un économiste-historien, | |
| l'utilité aval d'un OCR ne se mesure pas seulement au CER ; ce qui | |
| compte c'est de savoir si les **entités nommées** (personnes, lieux, | |
| dates, organisations) ont survécu à la transcription. Un CER de 5 % | |
| qui rate 80 % des noms propres est inutilisable pour l'indexation | |
| prosopographique. | |
| Stratégie de découpage en sprints | |
| --------------------------------- | |
| Comme pour la divergence taxonomique (Sprints 35-37), on découpe : | |
| - **Sprint 38** (ici) — couche de calcul pure : alignement IoU entre | |
| deux listes d'entités, calcul de Precision/Recall/F1 par catégorie | |
| et global, détection des hallucinations d'entité. Aucune dépendance | |
| externe (pas de spaCy, pas de Stanza) ; les listes d'entités sont | |
| fournies en entrée. Un test de l'enregistrement dans le registre | |
| typé Sprint 34 garantit l'intégration. | |
| - **Sprint à venir** — backend extracteur (spaCy / Stanza / HIPE) et | |
| câblage runner+narratif+HTML. | |
| Format des entités | |
| ------------------ | |
| Compatible avec ``EntitiesGT`` du Sprint 32 — chaque entité est un | |
| dictionnaire ``{"label": str, "start": int, "end": int, "text": str}`` | |
| où ``start``/``end`` sont des offsets caractère. | |
| Convention d'alignement | |
| ----------------------- | |
| Une entité hypothèse "matche" une entité de référence si : | |
| 1. les **labels sont identiques** (case-insensitive), | |
| 2. le ratio d'**Intersection-over-Union** (IoU) sur leurs spans | |
| caractère est ``≥ iou_threshold`` (défaut : 0,5). | |
| Une entité de référence non matchée → faux négatif (recall pénalisé). | |
| Une entité hypothèse non matchée → faux positif (précision pénalisée). | |
| Un faux positif est aussi compté comme **hallucination d'entité**, ce | |
| qui est utile pour les VLM/LLM qui inventent. | |
| Limites | |
| ------- | |
| - L'alignement bag-of-spans : une entité peut être matchée par au plus | |
| une entité de l'autre côté (sinon double-comptage). | |
| - Les modèles NER (spaCy, etc.) hallucinent eux-mêmes. La métrique | |
| mesure conjointement OCR + NER. Documenter explicitement. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import Iterable | |
| from picarones.evaluation.metric_registry import register_metric | |
| from picarones.domain.artifacts import ArtifactType | |
| logger = logging.getLogger(__name__) | |
| # ────────────────────────────────────────────────────────────────────────── | |
| # Modèle de données | |
| # ────────────────────────────────────────────────────────────────────────── | |
| class Entity: | |
| """Entité nommée alignée sur un texte. | |
| Attributs | |
| --------- | |
| label: | |
| Catégorie de l'entité (ex. ``"PER"``, ``"LOC"``, ``"DATE"``). | |
| La comparaison se fait en *case-insensitive*. | |
| start, end: | |
| Offsets caractère (inclus, exclu) sur le texte de référence. | |
| text: | |
| Forme de surface — informative, **non utilisée pour | |
| l'alignement** (deux entités peuvent matcher même si leur | |
| forme de surface diffère, du moment que leurs spans | |
| chevauchent suffisamment). | |
| """ | |
| label: str | |
| start: int | |
| end: int | |
| text: str = "" | |
| def __post_init__(self) -> None: | |
| if self.start > self.end: | |
| raise ValueError( | |
| f"Entity span invalide : start={self.start} > end={self.end}" | |
| ) | |
| def length(self) -> int: | |
| return max(0, self.end - self.start) | |
| def _to_entity(obj: Entity | dict) -> Entity: | |
| """Coerce un dict (format EntitiesGT) en ``Entity``.""" | |
| if isinstance(obj, Entity): | |
| return obj | |
| return Entity( | |
| label=str(obj["label"]), | |
| start=int(obj["start"]), | |
| end=int(obj["end"]), | |
| text=str(obj.get("text", "")), | |
| ) | |
| # ────────────────────────────────────────────────────────────────────────── | |
| # Alignement par IoU | |
| # ────────────────────────────────────────────────────────────────────────── | |
| def _iou(a: Entity, b: Entity) -> float: | |
| """Intersection-over-Union sur les spans caractère.""" | |
| inter_start = max(a.start, b.start) | |
| inter_end = min(a.end, b.end) | |
| inter = max(0, inter_end - inter_start) | |
| union = a.length + b.length - inter | |
| if union <= 0: | |
| return 0.0 | |
| return inter / union | |
| def _align( | |
| references: list[Entity], | |
| hypotheses: list[Entity], | |
| iou_threshold: float, | |
| ) -> tuple[list[tuple[int, int, float]], set[int], set[int]]: | |
| """Aligne deux listes d'entités par IoU décroissant (greedy). | |
| Returns | |
| ------- | |
| matches: | |
| Liste de triplets ``(idx_ref, idx_hyp, iou)`` triés par IoU | |
| décroissant — chaque entité n'apparaît qu'une fois. | |
| unmatched_refs: | |
| Indices des entités GT non matchées (faux négatifs). | |
| unmatched_hyps: | |
| Indices des entités hypothèse non matchées (faux positifs). | |
| """ | |
| candidates: list[tuple[float, int, int]] = [] | |
| for i, r in enumerate(references): | |
| for j, h in enumerate(hypotheses): | |
| if r.label.casefold() != h.label.casefold(): | |
| continue | |
| score = _iou(r, h) | |
| if score >= iou_threshold: | |
| candidates.append((score, i, j)) | |
| # Tri par IoU décroissant ; à IoU égale, on prend l'ordre des paires | |
| # pour garantir un tri stable et déterministe. | |
| candidates.sort(key=lambda t: (-t[0], t[1], t[2])) | |
| matched_refs: set[int] = set() | |
| matched_hyps: set[int] = set() | |
| matches: list[tuple[int, int, float]] = [] | |
| for score, i, j in candidates: | |
| if i in matched_refs or j in matched_hyps: | |
| continue | |
| matched_refs.add(i) | |
| matched_hyps.add(j) | |
| matches.append((i, j, score)) | |
| unmatched_refs = set(range(len(references))) - matched_refs | |
| unmatched_hyps = set(range(len(hypotheses))) - matched_hyps | |
| return matches, unmatched_refs, unmatched_hyps | |
| # ────────────────────────────────────────────────────────────────────────── | |
| # Calcul des métriques | |
| # ────────────────────────────────────────────────────────────────────────── | |
| def _prf(tp: int, fp: int, fn: int) -> dict[str, float]: | |
| """Précision / rappel / F1 à partir des comptes.""" | |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = ( | |
| 2 * precision * recall / (precision + recall) | |
| if (precision + recall) > 0 | |
| else 0.0 | |
| ) | |
| return { | |
| "precision": precision, | |
| "recall": recall, | |
| "f1": f1, | |
| "support": tp + fn, | |
| } | |
| def compute_ner_metrics( | |
| reference_entities: Iterable[Entity | dict], | |
| hypothesis_entities: Iterable[Entity | dict], | |
| iou_threshold: float = 0.5, | |
| ) -> dict: | |
| """Calcule la précision/rappel/F1 sur entités nommées. | |
| Parameters | |
| ---------- | |
| reference_entities: | |
| Liste d'entités GT (format ``Entity`` ou dict de | |
| ``EntitiesGT``). | |
| hypothesis_entities: | |
| Liste d'entités produites par le NER sur la sortie OCR. | |
| iou_threshold: | |
| Seuil de chevauchement caractère pour qu'un appariement | |
| soit valide (défaut : 0,5 — convention CoNLL/HIPE). | |
| Returns | |
| ------- | |
| dict | |
| ``{ | |
| "global": {"precision", "recall", "f1", "support"}, | |
| "per_category": {label: {"precision", ...}}, | |
| "true_positives": int, | |
| "false_positives": int, | |
| "false_negatives": int, | |
| "hallucinated_entities": list[dict], # entités OCR sans GT | |
| "missed_entities": list[dict], # entités GT non détectées | |
| "iou_threshold": float, | |
| }`` | |
| """ | |
| refs = [_to_entity(e) for e in reference_entities] | |
| hyps = [_to_entity(e) for e in hypothesis_entities] | |
| matches, unmatched_refs, unmatched_hyps = _align(refs, hyps, iou_threshold) | |
| tp = len(matches) | |
| fn = len(unmatched_refs) | |
| fp = len(unmatched_hyps) | |
| # Comptes par catégorie | |
| cat_tp: dict[str, int] = {} | |
| cat_fn: dict[str, int] = {} | |
| cat_fp: dict[str, int] = {} | |
| for i, _j, _score in matches: | |
| cat = refs[i].label | |
| cat_tp[cat] = cat_tp.get(cat, 0) + 1 | |
| for i in unmatched_refs: | |
| cat = refs[i].label | |
| cat_fn[cat] = cat_fn.get(cat, 0) + 1 | |
| for j in unmatched_hyps: | |
| cat = hyps[j].label | |
| cat_fp[cat] = cat_fp.get(cat, 0) + 1 | |
| all_categories = sorted(set(cat_tp) | set(cat_fn) | set(cat_fp)) | |
| per_category = { | |
| cat: _prf(cat_tp.get(cat, 0), cat_fp.get(cat, 0), cat_fn.get(cat, 0)) | |
| for cat in all_categories | |
| } | |
| return { | |
| "global": _prf(tp, fp, fn), | |
| "per_category": per_category, | |
| "true_positives": tp, | |
| "false_positives": fp, | |
| "false_negatives": fn, | |
| "hallucinated_entities": [ | |
| {"label": hyps[j].label, "start": hyps[j].start, | |
| "end": hyps[j].end, "text": hyps[j].text} | |
| for j in sorted(unmatched_hyps) | |
| ], | |
| "missed_entities": [ | |
| {"label": refs[i].label, "start": refs[i].start, | |
| "end": refs[i].end, "text": refs[i].text} | |
| for i in sorted(unmatched_refs) | |
| ], | |
| "iou_threshold": iou_threshold, | |
| } | |
| # ────────────────────────────────────────────────────────────────────────── | |
| # Enregistrement dans le registre typé (Sprint 34) | |
| # ────────────────────────────────────────────────────────────────────────── | |
| def ner_f1( | |
| reference_entities: Iterable[Entity | dict], | |
| hypothesis_entities: Iterable[Entity | dict], | |
| ) -> float: | |
| """F1 global ; raccourci enregistré pour les jonctions ``(ENTITIES, ENTITIES)``.""" | |
| return compute_ner_metrics(reference_entities, hypothesis_entities)["global"]["f1"] | |
| __all__ = [ | |
| "Entity", | |
| "compute_ner_metrics", | |
| "ner_f1", | |
| ] | |