"""
Standalone loader for solailabs/wmt22-cometkiwi-da-* variants.

Unlike the original `load.py`, this does NOT download the gated base
`Unbabel/wmt22-cometkiwi-da`. Instead it instantiates an empty COMET model
via `load_pretrained_weights=False` and loads the fine-tuned state_dict
shipped in this repo. Only ungated assets are fetched at load time
(the `microsoft/infoxlm-large` tokenizer + config, ~5 MB on first load,
cached after).

Usage
-----
    from huggingface_hub import snapshot_download
    import sys
    folder = snapshot_download("solailabs/wmt22-cometkiwi-da-int8")
    sys.path.insert(0, folder)
    from load import load_model
    model = load_model(folder)
    out = model.predict([{"src": "...", "mt": "..."}], batch_size=8, gpus=0)
"""
import inspect
import json
import platform
from pathlib import Path

import torch
import yaml
from torch.nn import Parameter, ParameterList

from comet.models import UnifiedMetric


def _prune_to_kept_layers(model, keep):
    """Keep only the encoder layers indexed by *keep* and shrink the layer mix to match."""
    backbone = model.encoder.model
    all_layers = backbone.encoder.layer
    backbone.encoder.layer = torch.nn.ModuleList([all_layers[i] for i in keep])
    backbone.config.num_hidden_layers = len(keep)

    # Rebuild layerwise_attention so the scalar_parameters count matches the
    # kept layers. Slot 0 is always retained; encoder layer i maps to slot i + 1.
    la = model.layerwise_attention
    mix_keep = [0] + [i + 1 for i in keep]
    la.scalar_parameters = ParameterList([
        Parameter(la.scalar_parameters[i].data.clone(), requires_grad=True)
        for i in mix_keep
    ])
    la.num_layers = len(mix_keep)
    if hasattr(la, "dropout_mask"):
        la.dropout_mask = torch.zeros(len(mix_keep))
        la.dropout_fill = torch.empty(len(mix_keep)).fill_(-1e20)


def _align_dtypes(state, reference):
    """Return a copy of *state* with tensors cast to the dtype *reference* uses for the same key."""
    aligned = {}
    for key, value in state.items():
        target = reference.get(key)
        if (
            isinstance(value, torch.Tensor)
            and isinstance(target, torch.Tensor)
            and value.dtype != target.dtype
        ):
            aligned[key] = value.to(target.dtype)
        else:
            # Unknown keys and non-tensor entries are passed through untouched.
            aligned[key] = value
    return aligned


def load_model(folder=None) -> UnifiedMetric:
    """Build a `UnifiedMetric` from the files in *folder* without the gated base checkpoint.

    Reads `config.json`, `hparams.yaml` and `state_dict.pt` from *folder*,
    instantiates the model with `load_pretrained_weights=False`, optionally
    prunes the encoder to the layers listed in the config, loads the shipped
    weights and, for variants flagged as such, applies int8 dynamic
    quantization to the encoder.

    Args:
        folder: Directory containing the three files above. Defaults to the
            directory this module lives in.

    Returns:
        The loaded model, switched to eval mode.

    Raises:
        FileNotFoundError: If one of the expected files is missing.
        KeyError: If `config.json` has no `keep_idx` entry.
    """
    folder = Path(folder) if folder else Path(__file__).parent
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    cfg = json.loads((folder / "config.json").read_text(encoding="utf-8"))
    # An empty YAML file parses to None; treat that as "no hyper-parameters".
    hparams = yaml.safe_load((folder / "hparams.yaml").read_text(encoding="utf-8")) or {}

    # Filter to keys UnifiedMetric actually accepts (hparams.yaml has extras
    # like class_identifier / layer / pool / word_weights that are legacy).
    accepted = set(inspect.signature(UnifiedMetric.__init__).parameters)
    hparams = {k: v for k, v in hparams.items() if k in accepted}

    # Skip the 2 GB gated download. Only the infoxlm-large tokenizer/config
    # is fetched (ungated, ~5 MB, cached). The flag is written into the dict
    # (instead of being passed next to **hparams) so that an hparams.yaml that
    # already carries this key cannot trigger a duplicate-keyword TypeError.
    hparams["load_pretrained_weights"] = False
    model = UnifiedMetric(**hparams)

    keep = cfg["keep_idx"]
    n_full = cfg.get("orig_num_layers", 24)

    # Truncate encoder to the pruned shape (no-op for int8 variant where keep == all).
    if len(keep) != n_full:
        _prune_to_kept_layers(model, keep)

    # Load our fine-tuned (possibly pruned, possibly fp16-stored) weights.
    # NOTE(security): weights_only=False unpickles arbitrary objects, so
    # state_dict.pt must come from a trusted source. If the file only holds
    # tensors, weights_only=True would likely work — TODO confirm.
    state = torch.load(folder / "state_dict.pt", map_location="cpu", weights_only=False)
    # strict=False: keys missing from / unknown to the model are ignored.
    model.load_state_dict(_align_dtypes(state, model.state_dict()), strict=False)

    # Apply int8 dynamic quantization at load if the variant ships fp16 storage.
    if cfg.get("quantized") and cfg.get("fp16_storage"):
        # qnnpack on ARM machines, fbgemm everywhere else.
        engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm"
        torch.backends.quantized.engine = engine
        model.encoder.model = torch.quantization.quantize_dynamic(
            model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8
        )

    model.eval()
    return model