solailabs commited on
Commit
0858348
·
verified ·
1 Parent(s): 07e2c61

Initial release: wmt22-cometkiwi-da-int8

Browse files
Files changed (4) hide show
  1. README.md +97 -0
  2. config.json +39 -0
  3. load.py +72 -0
  4. state_dict.pt +3 -0
README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - multilingual
5
+ tags:
6
+ - translation
7
+ - quality-estimation
8
+ - reference-free
9
+ - comet
10
+ - cometkiwi
11
+ - pruning
12
+ base_model: Unbabel/wmt22-cometkiwi-da
13
+ pipeline_tag: translation
14
+ ---
15
+
16
+ # wmt22-cometkiwi-da-int8
17
+
18
+ A compressed version of [Unbabel/wmt22-cometkiwi-da](https://huggingface.co/Unbabel/wmt22-cometkiwi-da) — a reference-free machine-translation quality estimation model (source + MT only, no human reference required).
19
+
20
+ **Lossless compression** — zero human-Pearson loss, ~40% smaller on disk via int8 alone.
21
+
22
+ ## What's different from the base model
23
+
24
+ - ****No layer pruning** — all 24 XLM-R encoder layers retained. Compression comes entirely from dynamic int8 quantization + fp16 storage.**
25
+ - `layerwise_attention` rebuilt to mix only the surviving layers (embeddings + kept layer outputs).
26
+ - **Dynamic int8 quantization** on the XLM-R encoder + fp16 storage (cast back to fp32 at load before quant). No layer pruning — all 24 encoder layers retained.
27
+
28
+ ## Accuracy
29
+
30
+ Benchmarked on 1200 stratified segments from [RicardoRei/wmt-da-human-evaluation](https://huggingface.co/datasets/RicardoRei/wmt-da-human-evaluation) (reference-free, src+mt only):
31
+
32
+ | Metric | This variant | Full cometkiwi |
33
+ |---|---|---|
34
+ | Pearson r vs human DA | **0.6404** | 0.6402 |
35
+ | Spearman vs human DA | **0.6703** | 0.6698 |
36
+ | Pearson r vs full | **0.9919** | 1.0000 |
37
+ | MAE vs full | **0.0138** | 0.0000 |
38
+ | Params | **565.1M** | 565.1M |
39
+ | On-disk size | **~1130 MB** | ~2200 MB |
40
+
41
+ ### All variants at a glance
42
+
43
+ | Variant | Pearson(human) | Pearson(full) | Size | When to use |
44
+ |---|---|---|---|---|
45
+ | [full base](https://huggingface.co/Unbabel/wmt22-cometkiwi-da) | 0.6402 | 1.0000 | ~2200 MB | reference quality |
46
+ | [`-int8`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-int8) | **0.6404** | 0.9919 | ~1300 MB | **lossless compression** |
47
+ | [`-pruned-k2`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-pruned-k2) | **0.6300** | 0.9784 | ~2100 MB | best-quality pruned |
48
+ | [`-pruned-k4`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-pruned-k4) | 0.5642 | 0.8316 | ~2060 MB | aggressive prune |
49
+ | [`-pruned-k4-xs`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-pruned-k4-xs) | 0.5544 | 0.8113 | ~1030 MB | smallest footprint |
50
+
51
+ ## Usage
52
+
53
+ ```python
54
+ # pip install "unbabel-comet" "setuptools<81" huggingface_hub
55
+ # export HF_TOKEN=<your_token> # must have Unbabel/wmt22-cometkiwi-da access
56
+
57
+ from huggingface_hub import snapshot_download
58
+ import sys
59
+ folder = snapshot_download(repo_id="solailabs/wmt22-cometkiwi-da-int8")
60
+ sys.path.insert(0, folder)
61
+ from load import load_model
62
+
63
+ model = load_model(folder)
64
+ out = model.predict(
65
+ [{{"src": "The meeting has been postponed until next week.",
66
+ "mt": "La réunion a été reportée à la semaine prochaine."}}],
67
+ batch_size=8, gpus=0, progress_bar=False, num_workers=2,
68
+ )
69
+ print(out["scores"])
70
+ ```
71
+
72
+ The loader re-downloads the base cometkiwi, drops the same encoder layers, optionally applies int8 dynamic quantization, then loads the weights shipped in this repo.
73
+
74
+ ## Files
75
+
76
+ - `state_dict.pt` — pruned model weights
77
+ - `config.json` — base model id, kept/dropped layer indices, quant flag, accuracy
78
+ - `load.py` — drop-in loader
79
+ - `README.md` — this file
80
+
81
+ ## Gated base model
82
+
83
+ The base `Unbabel/wmt22-cometkiwi-da` is gated. You must accept its license on the Hub while logged in with the same account your `HF_TOKEN` belongs to — otherwise the base-model download inside `load.py` returns 403.
84
+
85
+ ## Citation
86
+
87
+ **Base model:** [`Unbabel/wmt22-cometkiwi-da`](https://huggingface.co/Unbabel/wmt22-cometkiwi-da) by Unbabel.
88
+
89
+ ```
90
+ @inproceedings{{rei-etal-2022-cometkiwi,
91
+ title = "{{C}}omet{{K}}iwi: {{IST}}-{{U}}nbabel 2022 Submission for the Quality Estimation Shared Task",
92
+ author = "Rei, Ricardo and others",
93
+ booktitle = "WMT 2022",
94
+ }}
95
+ ```
96
+
97
+ Released under the same license as the base model (Apache 2.0).
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model": "Unbabel/wmt22-cometkiwi-da",
3
+ "orig_num_layers": 24,
4
+ "keep_idx": [
5
+ 0,
6
+ 1,
7
+ 2,
8
+ 3,
9
+ 4,
10
+ 5,
11
+ 6,
12
+ 7,
13
+ 8,
14
+ 9,
15
+ 10,
16
+ 11,
17
+ 12,
18
+ 13,
19
+ 14,
20
+ 15,
21
+ 16,
22
+ 17,
23
+ 18,
24
+ 19,
25
+ 20,
26
+ 21,
27
+ 22,
28
+ 23
29
+ ],
30
+ "dropped": [],
31
+ "tag": "cometkiwi_int8",
32
+ "quantized": true,
33
+ "quant_dtype": "qint8",
34
+ "fp16_storage": true,
35
+ "pearson_vs_full": 0.9919,
36
+ "mae_vs_full": 0.0138,
37
+ "pearson_human": 0.6404,
38
+ "params_M": 565.137435
39
+ }
load.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Drop-in loader for solailabs/wmt22-comet-da-pruned* models.
3
+
4
+ from huggingface_hub import snapshot_download
5
+ import sys
6
+ folder = snapshot_download(repo_id="solailabs/wmt22-comet-da-pruned-k4-int8")
7
+ sys.path.insert(0, folder)
8
+ from load import load_model
9
+ model = load_model()
10
+ print(model.predict([{"src": "...", "mt": "...", "ref": "..."}], gpus=0)["scores"])
11
+ """
12
+ import json
13
+ import platform
14
+ from pathlib import Path
15
+
16
+ import torch
17
+ from comet import download_model, load_from_checkpoint
18
+ from torch.nn import Parameter, ParameterList
19
+
20
+
21
+ def load_model(folder: str | Path | None = None):
22
+ """Reconstruct the pruned (and optionally int8-quantized) COMET model."""
23
+ folder = Path(folder) if folder else Path(__file__).parent
24
+ cfg = json.loads((folder / "config.json").read_text())
25
+
26
+ base_ckpt = download_model(cfg["base_model"])
27
+ model = load_from_checkpoint(base_ckpt)
28
+
29
+ keep = cfg["keep_idx"]
30
+ layers = model.encoder.model.encoder.layer
31
+ model.encoder.model.encoder.layer = torch.nn.ModuleList([layers[i] for i in keep])
32
+ model.encoder.model.config.num_hidden_layers = len(keep)
33
+
34
+ la = model.layerwise_attention
35
+ mix_keep = [0] + [i + 1 for i in keep]
36
+ la.scalar_parameters = ParameterList([
37
+ Parameter(la.scalar_parameters[i].data.clone(), requires_grad=True)
38
+ for i in mix_keep
39
+ ])
40
+ la.num_layers = len(mix_keep)
41
+ if hasattr(la, "dropout_mask"):
42
+ la.dropout_mask = torch.zeros(len(mix_keep))
43
+ la.dropout_fill = torch.empty(len(mix_keep)).fill_(-1e20)
44
+
45
+ quantize_at_load = cfg.get("quantized") and cfg.get("fp16_storage")
46
+ if cfg.get("quantized") and not quantize_at_load:
47
+ # Legacy path: state_dict contains already-quantized packed params
48
+ engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm"
49
+ torch.backends.quantized.engine = engine
50
+ model.encoder.model = torch.quantization.quantize_dynamic(
51
+ model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8
52
+ )
53
+
54
+ state = torch.load(folder / "state_dict.pt", map_location="cpu", weights_only=False)
55
+ own = model.state_dict()
56
+ fixed = {}
57
+ for k, v in state.items():
58
+ if k in own and isinstance(v, torch.Tensor) and isinstance(own[k], torch.Tensor) and v.dtype != own[k].dtype:
59
+ fixed[k] = v.to(own[k].dtype)
60
+ else:
61
+ fixed[k] = v
62
+ model.load_state_dict(fixed, strict=False)
63
+
64
+ if quantize_at_load:
65
+ # Quantize AFTER loading fp16/fp32 weights
66
+ engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm"
67
+ torch.backends.quantized.engine = engine
68
+ model.encoder.model = torch.quantization.quantize_dynamic(
69
+ model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8
70
+ )
71
+ model.eval()
72
+ return model
state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:720b028f0fc062623a63fefb5f6564289e9a6107fc293cdbc6ef79031072b929
3
+ size 1130416312