qovaryx-350m-scratch-base / modeling_qovaryx.py
tjarvis91's picture
Initial release: Qovaryx random-init scratch base, Apache-2.0
9cdb2ab verified
"""HF PreTrainedModel wrapper around Qovaryx FinanceDecoder.
Exposes the randomly initialized Qovaryx compact decoder through the standard HF
AutoModelForCausalLM.from_pretrained API; loading requires trust_remote_code=True.
The underlying class is named FinanceDecoder for legacy reasons; the
architecture is task-agnostic.
"""
from __future__ import annotations
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_qovaryx import QovaryxConfig
from ._qovaryx_decoder import DecoderConfig, FinanceDecoder
def _to_decoder_config(cfg: QovaryxConfig) -> DecoderConfig:
    """Build the decoder's own ``DecoderConfig`` from an HF ``QovaryxConfig``.

    Each listed field is read from ``cfg`` by name and passed through
    unchanged; no renaming, defaulting, or validation happens here. A missing
    attribute on ``cfg`` surfaces as ``AttributeError``.
    """
    # Field names shared by both config classes, read in this fixed order.
    shared_fields = (
        "vocab_size",
        "d_model",
        "n_layer",
        "n_head",
        "n_kv_head",
        "d_ff",
        "max_seq_len",
        "rope_base",
        "rms_eps",
        "dropout",
        "decision_head_classes",
        "decision_head_dropout",
        "mtp_k",
        "mtp_head_kind",
        "init_std",
        "tie_word_embeddings",
        "ffn_kind",
        "ffn_rank",
        "ffn_experts",
        "ffn_top_k",
        "chart_patch_encoder_enabled",
        "chart_image_size",
        "chart_patch_size",
        "chart_channels",
        "chart_embed_dropout",
    )
    kwargs = {field: getattr(cfg, field) for field in shared_fields}
    return DecoderConfig(**kwargs)
class QovaryxForCausalLM(PreTrainedModel):
    """Causal-LM wrapper exposing ``FinanceDecoder`` through the HF model API.

    The wrapped decoder lives in ``self.decoder``; every parameter belongs to it.
    """

    config_class = QovaryxConfig
    # NOTE(review): this class has no attribute named ``qovaryx`` (the decoder
    # is ``self.decoder``). Left unchanged because altering the prefix can
    # change how HF matches checkpoint keys when loading.
    base_model_prefix = "qovaryx"
    supports_gradient_checkpointing = False

    def __init__(self, config: QovaryxConfig):
        super().__init__(config)
        self.decoder = FinanceDecoder(_to_decoder_config(config))
        # HF's standard end-of-constructor hook; weight init inside it is a
        # no-op for this model (see _init_weights).
        self.post_init()

    def _init_weights(self, module):
        # Deliberately a no-op: FinanceDecoder already initializes its own
        # parameters, so HF must not re-initialize them.
        pass

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module (``decoder.embed``)."""
        return self.decoder.embed

    def set_input_embeddings(self, value):
        """Replace the decoder's token-embedding module."""
        self.decoder.embed = value

    def get_output_embeddings(self):
        """Return the decoder's output projection (``decoder.lm_head``)."""
        return self.decoder.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder's output projection.

        Counterpart of ``get_output_embeddings``. HF utilities such as
        ``resize_token_embeddings`` call this to install a resized head.
        """
        self.decoder.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor | None = None,
        attention_mask=None,
        **kwargs,
    ):
        """Run the decoder and optionally compute the next-token LM loss.

        Args:
            input_ids: Token ids, passed unchanged to the decoder.
            labels: Optional target ids aligned with ``input_ids``. Positions
                equal to -100 are ignored. Labels are shifted internally, so
                they may simply be a copy of ``input_ids``.
            attention_mask: Accepted for HF API compatibility but NOT used.
                The decoder receives ``input_ids`` only, so padded positions
                are not masked.
            **kwargs: Any other HF keyword arguments; ignored.

        Returns:
            ``CausalLMOutput`` holding ``logits`` and, when ``labels`` is
            given, the mean cross-entropy over non-ignored positions as
            ``loss``.
        """
        out = self.decoder(input_ids)
        # The decoder may return an object with ``.logits`` or the logits
        # tensor itself; accept both.
        logits = out.logits if hasattr(out, "logits") else out

        loss = None
        if labels is not None:
            # Put labels on the logits' device (needed when the model is split
            # across devices) and compute the loss in float32 so half/bf16
            # logits do not lose precision in the log-softmax. The returned
            # logits keep their original dtype.
            labels = labels.to(logits.device)
            # Position t predicts token t+1: drop the last logit and the
            # first label.
            shift_logits = logits[..., :-1, :].float()
            shift_labels = labels[..., 1:]
            loss = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # The full sequence is returned unchanged, so no KV cache is used.
        # Every other generate() kwarg (including attention_mask) is dropped.
        return {"input_ids": input_ids}

    @classmethod
    def can_generate(cls):
        # Advertise .generate() support to HF.
        # NOTE(review): recent transformers releases only provide .generate()
        # on classes that also inherit GenerationMixin. Confirm this against
        # the pinned transformers version.
        return True