Spaces:
Sleeping
Sleeping
File size: 9,233 Bytes
03e7c21 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 | """Métrique d'absorption d'erreur — Sprint 94 (B.3).
Sprint 94 — B.3 du plan d'évolution 2026.
Pourquoi ce module
------------------
Quand un module de post-correction LLM aplatit les différences
entre OCR amont, ce n'est pas qu'il « améliore » tous les
moteurs — c'est qu'il introduit ses propres biais qui dominent
ceux de l'OCR. Mesurer la dégradation par étape ne suffit
pas : il faut **séparer** les deux flux.
À chaque jonction où un module transforme un artefact, on
mesure :
- **Taux de correction** : parmi les erreurs présentes en
entrée du module, combien sont corrigées en sortie ?
- **Taux d'introduction** : parmi les erreurs présentes en
sortie, combien sont **nouvelles** (absentes en entrée) ?
C'est la généralisation du score de sur-normalisation
(chantier A.I.7) à toute jonction. La formule s'applique
uniformément à OCR→LLM, OCR→reconstructor, VLM→ALTO_mapper —
toute jonction qui transforme un artefact en un autre du même
type.
Méthode (token-level)
---------------------
On split en tokens whitespace ``reference``, ``before``,
``after``. On compare en **multiset** (un token GT consommé
au plus une fois) :
- ``errors_before`` = tokens GT non retrouvés dans ``before``
- ``errors_after`` = tokens GT non retrouvés dans ``after``
- ``corrected`` = ``errors_before \\ errors_after``
(présents avant, absents après → corrigés)
- ``introduced`` = ``errors_after \\ errors_before``
(absents avant, présents après → introduits)
Garde-fou : le module ne classe pas les erreurs (visuelles,
abréviations, etc.) — c'est une métrique d'**absorption de
volume**, pas de qualité éditoriale. L'intersection sémantique
avec ``taxonomy`` (Sprint 5) est documentée dans le glossaire.
Sortie
------
``compute_error_absorption(reference, before, after)`` retourne :
.. code-block:: text
{
"n_gt_tokens": int,
"n_errors_before": int,
"n_errors_after": int,
"n_corrected": int,
"n_introduced": int,
"n_kept_wrong": int,
"correction_rate": float | None, # n_corrected / n_errors_before
"introduction_rate": float | None, # n_introduced / n_errors_after
"net_improvement": int, # n_corrected - n_introduced
"corrected_tokens": list[str],
"introduced_tokens": list[str],
}
``aggregate_error_absorption(per_doc_results)`` somme les
compteurs corpus-wide et recalcule les taux *micro*.
"""
from __future__ import annotations
import logging
from collections import Counter
from typing import Iterable, Optional
logger = logging.getLogger(__name__)
def _split_words(text: Optional[str]) -> list[str]:
if not text:
return []
return text.split()
def _missing_tokens(
reference: list[str], hypothesis: list[str],
) -> Counter:
"""Tokens GT manquants en hypothèse au sens multiset.
Un token GT compte plusieurs fois s'il apparaît plusieurs
fois ; chaque occurrence en hypothèse en absorbe au plus
une. Retourne un Counter ``{token: nb_occurrences_manquees}``.
"""
ref_count = Counter(reference)
hyp_count = Counter(hypothesis)
missing: Counter = Counter()
for token, n_ref in ref_count.items():
n_hyp = hyp_count.get(token, 0)
if n_hyp < n_ref:
missing[token] = n_ref - n_hyp
return missing
def compute_error_absorption(
reference: Optional[str],
before: Optional[str],
after: Optional[str],
*,
case_sensitive: bool = False,
) -> Optional[dict]:
"""Mesure l'absorption d'erreur entre ``before`` et ``after``.
Parameters
----------
reference:
GT (vérité terrain).
before:
Sortie de l'étape précédente (typiquement OCR amont).
after:
Sortie de l'étape courante (typiquement post-correction LLM).
case_sensitive:
Si False (défaut), match case-insensitive — la sortie
``corrected_tokens``/``introduced_tokens`` reste en casse
GT originale.
Returns
-------
dict | None
``None`` si la GT est vide ou ne contient aucun token.
"""
ref_tokens = _split_words(reference)
if not ref_tokens:
return None
before_tokens = _split_words(before)
after_tokens = _split_words(after)
if case_sensitive:
ref_match = list(ref_tokens)
before_match = list(before_tokens)
after_match = list(after_tokens)
else:
ref_match = [t.lower() for t in ref_tokens]
before_match = [t.lower() for t in before_tokens]
after_match = [t.lower() for t in after_tokens]
# Map case-insensitive token → liste de casses GT originales
ref_orig_by_match: dict[str, list[str]] = {}
for orig, m in zip(ref_tokens, ref_match):
ref_orig_by_match.setdefault(m, []).append(orig)
missing_before = _missing_tokens(ref_match, before_match)
missing_after = _missing_tokens(ref_match, after_match)
n_errors_before = sum(missing_before.values())
n_errors_after = sum(missing_after.values())
# Calcul corrigé / introduit en multiset
corrected_counter: Counter = Counter()
introduced_counter: Counter = Counter()
kept_wrong_counter: Counter = Counter()
all_tokens = set(missing_before) | set(missing_after)
for tok in all_tokens:
nb = missing_before.get(tok, 0)
na = missing_after.get(tok, 0)
if nb > na:
corrected_counter[tok] = nb - na
kept_wrong_counter[tok] = na
elif na > nb:
introduced_counter[tok] = na - nb
kept_wrong_counter[tok] = nb
else:
kept_wrong_counter[tok] = nb
n_corrected = sum(corrected_counter.values())
n_introduced = sum(introduced_counter.values())
n_kept_wrong = sum(kept_wrong_counter.values())
correction_rate = (
n_corrected / n_errors_before
if n_errors_before > 0 else None
)
introduction_rate = (
n_introduced / n_errors_after
if n_errors_after > 0 else None
)
def _expand(counter: Counter) -> list[str]:
out: list[str] = []
for tok, count in counter.items():
origs = ref_orig_by_match.get(tok, [tok])
# Ne renvoie que la casse représentative GT
display = origs[0] if origs else tok
out.extend([display] * count)
return out
return {
"n_gt_tokens": len(ref_tokens),
"n_errors_before": n_errors_before,
"n_errors_after": n_errors_after,
"n_corrected": n_corrected,
"n_introduced": n_introduced,
"n_kept_wrong": n_kept_wrong,
"correction_rate": correction_rate,
"introduction_rate": introduction_rate,
"net_improvement": n_corrected - n_introduced,
"corrected_tokens": _expand(corrected_counter),
"introduced_tokens": _expand(introduced_counter),
}
def aggregate_error_absorption(
per_doc: Iterable[Optional[dict]],
*,
sample_tokens: int = 50,
) -> Optional[dict]:
"""Agrège les compteurs corpus-wide et recalcule les taux
*micro*.
Parameters
----------
per_doc:
Itérable de sorties de ``compute_error_absorption`` (ou
``None`` pour les docs sans GT).
sample_tokens:
Nombre maximal de tokens corrigés/introduits gardés dans
l'échantillon (cap pour ne pas exploser le JSON).
Returns
-------
dict | None
``None`` si aucune entry valide.
"""
docs = [d for d in per_doc if d]
if not docs:
return None
n_gt = sum(int(d.get("n_gt_tokens") or 0) for d in docs)
n_errors_before = sum(int(d.get("n_errors_before") or 0) for d in docs)
n_errors_after = sum(int(d.get("n_errors_after") or 0) for d in docs)
n_corrected = sum(int(d.get("n_corrected") or 0) for d in docs)
n_introduced = sum(int(d.get("n_introduced") or 0) for d in docs)
n_kept_wrong = sum(int(d.get("n_kept_wrong") or 0) for d in docs)
correction_rate = (
n_corrected / n_errors_before if n_errors_before > 0 else None
)
introduction_rate = (
n_introduced / n_errors_after if n_errors_after > 0 else None
)
corrected_sample: list[str] = []
introduced_sample: list[str] = []
for d in docs:
corrected_sample.extend(d.get("corrected_tokens") or [])
introduced_sample.extend(d.get("introduced_tokens") or [])
if (
len(corrected_sample) >= sample_tokens
and len(introduced_sample) >= sample_tokens
):
break
return {
"n_docs": len(docs),
"n_gt_tokens": n_gt,
"n_errors_before": n_errors_before,
"n_errors_after": n_errors_after,
"n_corrected": n_corrected,
"n_introduced": n_introduced,
"n_kept_wrong": n_kept_wrong,
"correction_rate": correction_rate,
"introduction_rate": introduction_rate,
"net_improvement": n_corrected - n_introduced,
"corrected_tokens_sample": corrected_sample[:sample_tokens],
"introduced_tokens_sample": introduced_sample[:sample_tokens],
}
__all__ = [
"compute_error_absorption",
"aggregate_error_absorption",
]
|