""" Drop-in loader for solailabs/wmt22-comet-da-pruned* models. from huggingface_hub import snapshot_download import sys folder = snapshot_download(repo_id="solailabs/wmt22-comet-da-pruned-k4-int8") sys.path.insert(0, folder) from load import load_model model = load_model() print(model.predict([{"src": "...", "mt": "...", "ref": "..."}], gpus=0)["scores"]) """ import json import platform from pathlib import Path import torch from comet import download_model, load_from_checkpoint from torch.nn import Parameter, ParameterList def load_model(folder: str | Path | None = None): """Reconstruct the pruned (and optionally int8-quantized) COMET model.""" folder = Path(folder) if folder else Path(__file__).parent cfg = json.loads((folder / "config.json").read_text()) base_ckpt = download_model(cfg["base_model"]) model = load_from_checkpoint(base_ckpt) keep = cfg["keep_idx"] layers = model.encoder.model.encoder.layer model.encoder.model.encoder.layer = torch.nn.ModuleList([layers[i] for i in keep]) model.encoder.model.config.num_hidden_layers = len(keep) la = model.layerwise_attention mix_keep = [0] + [i + 1 for i in keep] la.scalar_parameters = ParameterList([ Parameter(la.scalar_parameters[i].data.clone(), requires_grad=True) for i in mix_keep ]) la.num_layers = len(mix_keep) if hasattr(la, "dropout_mask"): la.dropout_mask = torch.zeros(len(mix_keep)) la.dropout_fill = torch.empty(len(mix_keep)).fill_(-1e20) quantize_at_load = cfg.get("quantized") and cfg.get("fp16_storage") if cfg.get("quantized") and not quantize_at_load: # Legacy path: state_dict contains already-quantized packed params engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm" torch.backends.quantized.engine = engine model.encoder.model = torch.quantization.quantize_dynamic( model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8 ) state = torch.load(folder / "state_dict.pt", map_location="cpu", weights_only=False) own = model.state_dict() fixed = {} for k, v in state.items(): if k in own and isinstance(v, torch.Tensor) and isinstance(own[k], torch.Tensor) and v.dtype != own[k].dtype: fixed[k] = v.to(own[k].dtype) else: fixed[k] = v model.load_state_dict(fixed, strict=False) if quantize_at_load: # Quantize AFTER loading fp16/fp32 weights engine = "qnnpack" if platform.machine() in ("arm64", "aarch64") else "fbgemm" torch.backends.quantized.engine = engine model.encoder.model = torch.quantization.quantize_dynamic( model.encoder.model, {torch.nn.Linear}, dtype=torch.qint8 ) model.eval() return model