File size: 2,542 Bytes
b0b1bdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Courbes de distribution de la performance (Sprint 7).

- :func:`compute_reliability_curve` — pour les X % docs les plus
  faciles, quel est le CER moyen ? Révèle si un moteur a un long
  tail catastrophique.
- :func:`compute_venn_data` — cardinalités pour un diagramme de
  Venn 2 ou 3 moteurs sur les ensembles d'erreurs commises.
"""

from __future__ import annotations


def compute_reliability_curve(
    cer_values: list[float],
    steps: int = 20,
) -> list[dict]:
    """Pour les X% documents les plus faciles, quel est le CER moyen ?

    Returns
    -------
    Liste de {pct_docs: float, mean_cer: float}
    """
    if not cer_values:
        return []
    sorted_cer = sorted(cer_values)
    n = len(sorted_cer)
    points = []
    for step in range(1, steps + 1):
        pct = step / steps
        cutoff = max(1, int(pct * n))
        subset = sorted_cer[:cutoff]
        mean_cer = sum(subset) / len(subset)
        points.append({"pct_docs": round(pct * 100, 1), "mean_cer": round(mean_cer, 6)})
    return points


def compute_venn_data(
    engine_error_sets: dict[str, set[str]],
) -> dict:
    """Calcule les cardinalités pour un diagramme de Venn entre 2 ou 3 concurrents.

    Parameters
    ----------
    engine_error_sets : {engine_name → set of doc_id:error_token_pair strings}

    Returns
    -------
    Pour 2 concurrents :
      {only_a, only_b, both, label_a, label_b}
    Pour 3 concurrents :
      {only_a, only_b, only_c, ab, ac, bc, abc, label_a, label_b, label_c}
    """
    names = list(engine_error_sets.keys())[:3]  # max 3 pour Venn lisible
    if len(names) < 2:
        return {}

    sets = {n: engine_error_sets[n] for n in names}

    if len(names) == 2:
        a, b = names
        sa, sb = sets[a], sets[b]
        return {
            "type": "venn2",
            "label_a": a,
            "label_b": b,
            "only_a": len(sa - sb),
            "only_b": len(sb - sa),
            "both": len(sa & sb),
        }
    else:
        a, b, c = names
        sa, sb, sc = sets[a], sets[b], sets[c]
        return {
            "type": "venn3",
            "label_a": a,
            "label_b": b,
            "label_c": c,
            "only_a": len(sa - sb - sc),
            "only_b": len(sb - sa - sc),
            "only_c": len(sc - sa - sb),
            "ab": len((sa & sb) - sc),
            "ac": len((sa & sc) - sb),
            "bc": len((sb & sc) - sa),
            "abc": len(sa & sb & sc),
        }


__all__ = ["compute_reliability_curve", "compute_venn_data"]