"""Hugging Face AutoConfig support for Ogma models."""

from __future__ import annotations

from enum import StrEnum
from typing import Any

from transformers import PretrainedConfig

__all__ = ["OgmaConfig", "VariantType", "PoolingType", "TaskToken"]


class VariantType(StrEnum):
    """Architecture variant identifiers."""

    TRANSFORMER = "transformer"
    DEEP_NARROW = "deep_narrow"
    CONV = "conv"
    LINEAR_ATTENTION = "linear_attention"
    MLP_MIXER = "mlp_mixer"
    TRANSFORMER_RESA = "transformer_resa"
    GLA = "gla"


class PoolingType(StrEnum):
    """Pooling strategy identifiers."""

    TASK_TOKEN = "task_token"
    LATENT_ATTENTION = "latent_attention"
    MEAN = "mean"


class TaskToken(StrEnum):
    """Task token identifiers for asymmetric encoding."""

    QRY = "QRY"
    DOC = "DOC"
    SYM = "SYM"


class OgmaConfig(PretrainedConfig):
    """Configuration for Ogma embedding models."""

    model_type = "ogma"

    def __init__(
        self,
        variant: str | VariantType = VariantType.TRANSFORMER,
        d_embed: int = 128,
        d_model: int = 256,
        n_layers: int = 1,
        n_heads: int = 4,
        vocab_size: int = 30_000,
        max_seq_len: int = 512,
        matryoshka_dims: list[int] | None = None,
        pooling: str | PoolingType = PoolingType.TASK_TOKEN,
        d_output: int = 256,
        ffn_mult: float = 8 / 3,
        conv_kernel_size: int = 7,
        spatial_rank: int = 32,
        n_random_features: int = 128,
        dropout: float = 0.0,
        scorer_type: str = "dot",
        scorer_alpha_init: float = 0.1,
        scorer_hidden: int = 0,
        gla_expand_k: float = 0.5,
        gla_expand_v: float = 1.0,
        gla_gate_low_rank_dim: int = 16,
        gla_gate_logit_normalizer: int = 16,
        gla_use_short_conv: bool = True,
        gla_conv_size: int = 4,
        pad_id: int = 0,
        unk_id: int = 1,
        bos_id: int = 2,
        eos_id: int = 3,
        qry_id: int = 4,
        doc_id: int = 5,
        sym_id: int = 6,
        n_special_tokens: int = 7,
        **kwargs: Any,
    ) -> None:
        # Mirror the Ogma special-token IDs into the standard Hugging Face
        # fields unless the caller passed those fields explicitly.
        kwargs.setdefault("pad_token_id", pad_id)
        kwargs.setdefault("bos_token_id", bos_id)
        kwargs.setdefault("eos_token_id", eos_id)
        super().__init__(**kwargs)

        # Core architecture. Strings (e.g. values read back from config.json)
        # are coerced to their enum types.
        self.variant = VariantType(variant)
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # None or an empty list falls back to the default Matryoshka dims.
        self.matryoshka_dims = matryoshka_dims or [32, 64, 128, 256]
        self.pooling = PoolingType(pooling)
        self.d_output = d_output
        self.ffn_mult = ffn_mult

        # Variant-specific options.
        self.conv_kernel_size = conv_kernel_size
        self.spatial_rank = spatial_rank
        self.n_random_features = n_random_features
        self.dropout = dropout

        # Scorer settings.
        self.scorer_type = scorer_type
        self.scorer_alpha_init = scorer_alpha_init
        self.scorer_hidden = scorer_hidden

        # Settings used by the GLA variant.
        self.gla_expand_k = gla_expand_k
        self.gla_expand_v = gla_expand_v
        self.gla_gate_low_rank_dim = gla_gate_low_rank_dim
        self.gla_gate_logit_normalizer = gla_gate_logit_normalizer
        self.gla_use_short_conv = gla_use_short_conv
        self.gla_conv_size = gla_conv_size

        # Special-token IDs.
        self.pad_id = pad_id
        self.unk_id = unk_id
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.qry_id = qry_id
        self.doc_id = doc_id
        self.sym_id = sym_id
        self.n_special_tokens = n_special_tokens

    @property
    def d_head(self) -> int:
        """Per-head dimension."""
        return self.d_model // self.n_heads

    @property
    def ffn_hidden(self) -> int:
        """SwiGLU FFN hidden dimension."""
        return int(self.d_model * self.ffn_mult)

    def task_token_id(self, task: TaskToken | str) -> int:
        """Return the token ID for a task token."""
        task = TaskToken(task)
        return {
            TaskToken.QRY: self.qry_id,
            TaskToken.DOC: self.doc_id,
            TaskToken.SYM: self.sym_id,
        }[task]

    def to_dict(self) -> dict[str, Any]:
        """Serialize the config to a JSON-compatible dictionary."""
        output = super().to_dict()
        # Store enum fields as their plain string values.
        output["variant"] = self.variant.value
        output["pooling"] = self.pooling.value
        return output