"""HF PreTrainedModel wrapper around Qovaryx FinanceDecoder. Loads the random-init Qovaryx compact decoder via the standard HF AutoModelForCausalLM.from_pretrained API. trust_remote_code=True required. The underlying class is named FinanceDecoder for legacy reasons; the architecture is task-agnostic. """ from __future__ import annotations import torch from torch import nn from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutput from .configuration_qovaryx import QovaryxConfig from ._qovaryx_decoder import DecoderConfig, FinanceDecoder def _to_decoder_config(cfg: QovaryxConfig) -> DecoderConfig: return DecoderConfig( vocab_size=cfg.vocab_size, d_model=cfg.d_model, n_layer=cfg.n_layer, n_head=cfg.n_head, n_kv_head=cfg.n_kv_head, d_ff=cfg.d_ff, max_seq_len=cfg.max_seq_len, rope_base=cfg.rope_base, rms_eps=cfg.rms_eps, dropout=cfg.dropout, decision_head_classes=cfg.decision_head_classes, decision_head_dropout=cfg.decision_head_dropout, mtp_k=cfg.mtp_k, mtp_head_kind=cfg.mtp_head_kind, init_std=cfg.init_std, tie_word_embeddings=cfg.tie_word_embeddings, ffn_kind=cfg.ffn_kind, ffn_rank=cfg.ffn_rank, ffn_experts=cfg.ffn_experts, ffn_top_k=cfg.ffn_top_k, chart_patch_encoder_enabled=cfg.chart_patch_encoder_enabled, chart_image_size=cfg.chart_image_size, chart_patch_size=cfg.chart_patch_size, chart_channels=cfg.chart_channels, chart_embed_dropout=cfg.chart_embed_dropout, ) class QovaryxForCausalLM(PreTrainedModel): config_class = QovaryxConfig base_model_prefix = "qovaryx" supports_gradient_checkpointing = False def __init__(self, config: QovaryxConfig): super().__init__(config) self.decoder = FinanceDecoder(_to_decoder_config(config)) self.post_init() def _init_weights(self, module): # Already initialized by FinanceDecoder; leave alone. pass def get_input_embeddings(self): return self.decoder.embed def set_input_embeddings(self, value): self.decoder.embed = value def get_output_embeddings(self): return self.decoder.lm_head def forward( self, input_ids: torch.LongTensor, labels: torch.LongTensor = None, attention_mask=None, **kwargs, ): out = self.decoder(input_ids) logits = out.logits if hasattr(out, "logits") else out loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = nn.functional.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100, ) return CausalLMOutput(loss=loss, logits=logits) def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} @classmethod def can_generate(cls): return True