Instructions to use solailabs/wmt22-cometkiwi-da-int8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- COMET
How to use solailabs/wmt22-cometkiwi-da-int8 with COMET:
# No code snippets available yet for this library. # To use this model, check the repository files and the library's documentation. # Want to help? PRs adding snippets are welcome at: # https://github.com/huggingface/huggingface.js
- Notebooks
- Google Colab
- Kaggle
Initial release: wmt22-cometkiwi-da-int8
Browse files- README.md +97 -0
- config.json +39 -0
- load.py +72 -0
- state_dict.pt +3 -0
README.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- multilingual
|
| 5 |
+
tags:
|
| 6 |
+
- translation
|
| 7 |
+
- quality-estimation
|
| 8 |
+
- reference-free
|
| 9 |
+
- comet
|
| 10 |
+
- cometkiwi
|
| 11 |
+
- pruning
|
| 12 |
+
base_model: Unbabel/wmt22-cometkiwi-da
|
| 13 |
+
pipeline_tag: translation
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# wmt22-cometkiwi-da-int8
|
| 17 |
+
|
| 18 |
+
A compressed version of [Unbabel/wmt22-cometkiwi-da](https://huggingface.co/Unbabel/wmt22-cometkiwi-da) — a reference-free machine-translation quality estimation model (source + MT only, no human reference required).
|
| 19 |
+
|
| 20 |
+
**Lossless compression** — zero human-Pearson loss, ~40% smaller on disk via int8 alone.
|
| 21 |
+
|
| 22 |
+
## What's different from the base model
|
| 23 |
+
|
| 24 |
+
- ****No layer pruning** — all 24 XLM-R encoder layers retained. Compression comes entirely from dynamic int8 quantization + fp16 storage.**
|
| 25 |
+
- `layerwise_attention` rebuilt to mix only the surviving layers (embeddings + kept layer outputs).
|
| 26 |
+
- **Dynamic int8 quantization** on the XLM-R encoder + fp16 storage (cast back to fp32 at load before quant). No layer pruning — all 24 encoder layers retained.
|
| 27 |
+
|
| 28 |
+
## Accuracy
|
| 29 |
+
|
| 30 |
+
Benchmarked on 1200 stratified segments from [RicardoRei/wmt-da-human-evaluation](https://huggingface.co/datasets/RicardoRei/wmt-da-human-evaluation) (reference-free, src+mt only):
|
| 31 |
+
|
| 32 |
+
| Metric | This variant | Full cometkiwi |
|
| 33 |
+
|---|---|---|
|
| 34 |
+
| Pearson r vs human DA | **0.6404** | 0.6402 |
|
| 35 |
+
| Spearman vs human DA | **0.6703** | 0.6698 |
|
| 36 |
+
| Pearson r vs full | **0.9919** | 1.0000 |
|
| 37 |
+
| MAE vs full | **0.0138** | 0.0000 |
|
| 38 |
+
| Params | **565.1M** | 565.1M |
|
| 39 |
+
| On-disk size | **~1130 MB** | ~2200 MB |
|
| 40 |
+
|
| 41 |
+
### All variants at a glance
|
| 42 |
+
|
| 43 |
+
| Variant | Pearson(human) | Pearson(full) | Size | When to use |
|
| 44 |
+
|---|---|---|---|---|
|
| 45 |
+
| [full base](https://huggingface.co/Unbabel/wmt22-cometkiwi-da) | 0.6402 | 1.0000 | ~2200 MB | reference quality |
|
| 46 |
+
| [`-int8`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-int8) | **0.6404** | 0.9919 | ~1300 MB | **lossless compression** |
|
| 47 |
+
| [`-pruned-k2`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-pruned-k2) | **0.6300** | 0.9784 | ~2100 MB | best-quality pruned |
|
| 48 |
+
| [`-pruned-k4`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-pruned-k4) | 0.5642 | 0.8316 | ~2060 MB | aggressive prune |
|
| 49 |
+
| [`-pruned-k4-xs`](https://huggingface.co/solailabs/wmt22-cometkiwi-da-pruned-k4-xs) | 0.5544 | 0.8113 | ~1030 MB | smallest footprint |
|
| 50 |
+
|
| 51 |
+
## Usage
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
# pip install "unbabel-comet" "setuptools<81" huggingface_hub
|
| 55 |
+
# export HF_TOKEN=<your_token> # must have Unbabel/wmt22-cometkiwi-da access
|
| 56 |
+
|
| 57 |
+
from huggingface_hub import snapshot_download
|
| 58 |
+
import sys
|
| 59 |
+
folder = snapshot_download(repo_id="solailabs/wmt22-cometkiwi-da-int8")
|
| 60 |
+
sys.path.insert(0, folder)
|
| 61 |
+
from load import load_model
|
| 62 |
+
|
| 63 |
+
model = load_model(folder)
|
| 64 |
+
out = model.predict(
|
| 65 |
+
[{{"src": "The meeting has been postponed until next week.",
|
| 66 |
+
"mt": "La réunion a été reportée à la semaine prochaine."}}],
|
| 67 |
+
batch_size=8, gpus=0, progress_bar=False, num_workers=2,
|
| 68 |
+
)
|
| 69 |
+
print(out["scores"])
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
The loader re-downloads the base cometkiwi, drops the same encoder layers, optionally applies int8 dynamic quantization, then loads the weights shipped in this repo.
|
| 73 |
+
|
| 74 |
+
## Files
|
| 75 |
+
|
| 76 |
+
- `state_dict.pt` — pruned model weights
|
| 77 |
+
- `config.json` — base model id, kept/dropped layer indices, quant flag, accuracy
|
| 78 |
+
- `load.py` — drop-in loader
|
| 79 |
+
- `README.md` — this file
|
| 80 |
+
|
| 81 |
+
## Gated base model
|
| 82 |
+
|
| 83 |
+
The base `Unbabel/wmt22-cometkiwi-da` is gated. You must accept its license on the Hub while logged in with the same account your `HF_TOKEN` belongs to — otherwise the base-model download inside `load.py` returns 403.
|
| 84 |
+
|
| 85 |
+
## Citation
|
| 86 |
+
|
| 87 |
+
**Base model:** [`Unbabel/wmt22-cometkiwi-da`](https://huggingface.co/Unbabel/wmt22-cometkiwi-da) by Unbabel.
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
@inproceedings{{rei-etal-2022-cometkiwi,
|
| 91 |
+
title = "{{C}}omet{{K}}iwi: {{IST}}-{{U}}nbabel 2022 Submission for the Quality Estimation Shared Task",
|
| 92 |
+
author = "Rei, Ricardo and others",
|
| 93 |
+
booktitle = "WMT 2022",
|
| 94 |
+
}}
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
Released under the same license as the base model (Apache 2.0).
|
config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model": "Unbabel/wmt22-cometkiwi-da",
|
| 3 |
+
"orig_num_layers": 24,
|
| 4 |
+
"keep_idx": [
|
| 5 |
+
0,
|
| 6 |
+
1,
|
| 7 |
+
2,
|
| 8 |
+
3,
|
| 9 |
+
4,
|
| 10 |
+
5,
|
| 11 |
+
6,
|
| 12 |
+
7,
|
| 13 |
+
8,
|
| 14 |
+
9,
|
| 15 |
+
10,
|
| 16 |
+
11,
|
| 17 |
+
12,
|
| 18 |
+
13,
|
| 19 |
+
14,
|
| 20 |
+
15,
|
| 21 |
+
16,
|
| 22 |
+
17,
|
| 23 |
+
18,
|
| 24 |
+
19,
|
| 25 |
+
20,
|
| 26 |
+
21,
|
| 27 |
+
22,
|
| 28 |
+
23
|
| 29 |
+
],
|
| 30 |
+
"dropped": [],
|
| 31 |
+
"tag": "cometkiwi_int8",
|
| 32 |
+
"quantized": true,
|
| 33 |
+
"quant_dtype": "qint8",
|
| 34 |
+
"fp16_storage": true,
|
| 35 |
+
"pearson_vs_full": 0.9919,
|
| 36 |
+
"mae_vs_full": 0.0138,
|
| 37 |
+
"pearson_human": 0.6404,
|
| 38 |
+
"params_M": 565.137435
|
| 39 |
+
}
|
load.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Drop-in loader for solailabs/wmt22-comet-da-pruned* models.
|
| 3 |
+
|
| 4 |
+
from huggingface_hub import snapshot_download
|
| 5 |
+
import sys
|
| 6 |
+
folder = snapshot_download(repo_id="solailabs/wmt22-comet-da-pruned-k4-int8")
|
| 7 |
+
sys.path.insert(0, folder)
|
| 8 |
+
from load import load_model
|
| 9 |
+
model = load_model()
|
| 10 |
+
print(model.predict([{"src": "...", "mt": "...", "ref": "..."}], gpus=0)["scores"])
|
| 11 |
+
"""
|
| 12 |
+
import json
|
| 13 |
+
import platform
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from comet import download_model, load_from_checkpoint
|
| 18 |
+
from torch.nn import Parameter, ParameterList
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_model(folder: str | Path | None = None):
|
| 22 |
+
"""Reconstruct the pruned (and optionally int8-quantized) COMET model."""
|
| 23 |
+
folder = Path(folder) if folder else Path(__file__).parent
|
| 24 |
+
cfg = json.loads((folder / "config.json").read_text())
|
| 25 |
+
|
| 26 |
+
base_ckpt = download_model(cfg["base_model"])
|
| 27 |
+
model = load_from_checkpoint(base_ckpt)
|
| 28 |
+
|
| 29 |
+
keep = cfg["keep_idx"]
|
| 30 |
+
layers = model.encoder.model.encoder.layer
|
| 31 |
+
model.encoder.model.encoder.layer = torch.nn.ModuleList([layers[i] for i in keep])
|
| 32 |
+
model.encoder.model.config.num_hidden_layers = len(keep)
|
| 33 |
+
|
| 34 |
+
la = model.layerwise_attention
|
| 35 |
+
mix_keep = [0] + [i + 1 for i in keep]
|
| 36 |
+
la.scalar_parameters = ParameterList([
|
| 37 |
+
Parameter(la.scalar_parameters[i].data.clone(), requires_grad=True)
|
| 38 |
+
for i in mix_keep
|
| 39 |
+
])
|
| 40 |
+
la.num_layers = len(mix_keep)
|
| 41 |
+
if hasattr(la, "dropout_mask"):
|
| 42 |
+
la.dropout_mask = torch.zeros(len(mix_keep))
|
| 43 |
+
la.dropout_fill = torch.empty(len(mix_keep)).fill_(-1e20)
|
| 44 |
+
|
| 45 |
+
quantize_at_load = cfg.get("quantized") and cfg.get("fp16_storage")
|
| 46 |
+
if cfg.get("quantized") and not quantize_at_load:
|
| 47 |
+
# Legacy path: state_dict contains already-quantized packed params
|
| 48 |
+
engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm"
|
| 49 |
+
torch.backends.quantized.engine = engine
|
| 50 |
+
model.encoder.model = torch.quantization.quantize_dynamic(
|
| 51 |
+
model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
state = torch.load(folder / "state_dict.pt", map_location="cpu", weights_only=False)
|
| 55 |
+
own = model.state_dict()
|
| 56 |
+
fixed = {}
|
| 57 |
+
for k, v in state.items():
|
| 58 |
+
if k in own and isinstance(v, torch.Tensor) and isinstance(own[k], torch.Tensor) and v.dtype != own[k].dtype:
|
| 59 |
+
fixed[k] = v.to(own[k].dtype)
|
| 60 |
+
else:
|
| 61 |
+
fixed[k] = v
|
| 62 |
+
model.load_state_dict(fixed, strict=False)
|
| 63 |
+
|
| 64 |
+
if quantize_at_load:
|
| 65 |
+
# Quantize AFTER loading fp16/fp32 weights
|
| 66 |
+
engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm"
|
| 67 |
+
torch.backends.quantized.engine = engine
|
| 68 |
+
model.encoder.model = torch.quantization.quantize_dynamic(
|
| 69 |
+
model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8
|
| 70 |
+
)
|
| 71 |
+
model.eval()
|
| 72 |
+
return model
|
state_dict.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:720b028f0fc062623a63fefb5f6564289e9a6107fc293cdbc6ef79031072b929
|
| 3 |
+
size 1130416312
|