#!/usr/bin/env python3
"""
NEXUS-VideoModel v1.1 - HuggingFace Hub Modeling File
Auto-generated for dynamic loading and fine-tuning.

This file contains the complete model architecture for loading from HuggingFace Hub.
"""

import os
import math
import json
import logging
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


# ==============================================================================
# CONFIGURATION
# ==============================================================================

@dataclass
class VideoConfig:
    """Geometry of the input clips.

    ``height``/``width``/``channels`` are read by the VAE and by ``generate``;
    ``n_frames``, ``fps`` and the ``*_downsample`` factors are stored for
    serialization but are not read by the model code in this file.
    """
    height: int = 64
    width: int = 64
    channels: int = 3
    n_frames: int = 16
    fps: int = 8
    temporal_downsample: int = 4
    spatial_downsample: int = 8

    @classmethod
    def from_dict(cls, d: dict) -> "VideoConfig":
        """Build a ``VideoConfig`` from ``d``, silently ignoring unknown keys."""
        valid_keys = {'height', 'width', 'channels', 'n_frames', 'fps',
                      'temporal_downsample', 'spatial_downsample'}
        return cls(**{k: v for k, v in d.items() if k in valid_keys})


@dataclass
class NexusVideoModelConfig:
    """Hyper-parameters for :class:`NexusVideoModel`.

    Several fields (``n_heads``, ``use_gqa``, ``latent_dim``, ``memory_size``,
    ``memory_k``, ``min_neurons``, ``neuron_death_threshold``, ``energy_enabled``,
    the ``dream_*`` settings and the training / hub settings) are not read by the
    model code in this file -- attention heads come from ``gqa_num_heads`` and the
    LPOL memory uses a fixed 8 heads. They are kept so configs round-trip through
    :meth:`to_dict` / :meth:`from_dict`.
    """
    model_type: str = "nexus-videomodel"
    version: str = "1.1"
    codename: str = "VideoSim-Cognitive"
    architecture_type: str = "cognitive-video"
    d_model: int = 512
    d_ff: int = 2048
    n_layers: int = 8
    n_heads: int = 8
    dropout: float = 0.1
    latent_dim: int = 256
    temporal_latent_dim: int = 512
    latent_channels: int = 4
    max_frames: int = 64
    context_frames: int = 16
    prediction_frames: int = 8
    encoder_channels: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
    decoder_channels: List[int] = field(default_factory=lambda: [512, 256, 128, 64])
    kl_weight: float = 0.0001
    use_lpol: bool = True
    memory_size: int = 256
    memory_slots_per_domain: int = 32
    memory_k: int = 8
    domain_types: List[str] = field(default_factory=lambda: [
        'motion', 'appearance', 'temporal', 'spatial', 'object', 'scene', 'action', 'causality', 'physics'
    ])
    use_gqa: bool = True
    gqa_num_heads: int = 8
    gqa_num_kv_groups: int = 2
    expert_types: List[str] = field(default_factory=lambda: [
        'Motion', 'Appearance', 'Temporal', 'Spatial', 'Prediction', 'Generation'
    ])
    max_experts: int = 12
    growth_threshold_coherence: float = 0.3
    growth_patience: int = 10
    neurogenesis_enabled: bool = True
    min_neurons: int = 32
    max_neurons: int = 256
    neuron_birth_threshold: float = 0.8
    neuron_death_threshold: float = 0.05
    energy_enabled: bool = True
    energy_cost_encode: float = 0.01
    energy_cost_decode: float = 0.02
    energy_cost_predict: float = 0.03
    energy_regeneration: float = 0.05
    dream_enabled: bool = True
    dream_cycle_length: int = 100
    dream_duration: int = 20
    temporal_coherence_weight: float = 0.1
    flow_prediction: bool = True
    perceptual_loss_weight: float = 0.05
    batch_size: int = 8
    learning_rate: float = 1e-4
    epochs: int = 50
    gradient_accumulation: int = 4
    warmup_steps: int = 500
    push_to_hub: bool = True
    hub_model_id: str = "amewebstudio/nexus-videomodel-v1.1"
    video: VideoConfig = field(default_factory=VideoConfig)

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict of the config.

        The nested ``video`` config becomes a plain dict; attributes whose values
        are not list/dict/str/int/float/bool/None are skipped.
        """
        d = {}
        for key, value in self.__dict__.items():
            if key == 'video':
                d[key] = {k: v for k, v in value.__dict__.items()}
            elif isinstance(value, (list, dict, str, int, float, bool, type(None))):
                d[key] = value
        return d

    @classmethod
    def from_dict(cls, d: Dict) -> "NexusVideoModelConfig":
        """Inverse of :meth:`to_dict`.

        Rebuilds the nested ``VideoConfig``, drops checkpoint bookkeeping keys
        (``_dynamic_state`` etc.) and ignores any key that is not a config field.
        The input dict is not mutated.
        """
        d = d.copy()
        if 'video' in d and isinstance(d['video'], dict):
            d['video'] = VideoConfig.from_dict(d['video'])
        for k in ['_dynamic_state', '_class_name', '_version', '_architecture']:
            d.pop(k, None)
        known = set(cls.__dataclass_fields__.keys())
        return cls(**{k: v for k, v in d.items() if k in known})


# ==============================================================================
# BUILDING BLOCKS
# ==============================================================================

class CausalConv3d(nn.Module):
    """``nn.Conv3d`` that only looks backwards in time.

    For an ``(B, C, T, H, W)`` input (``nn.Conv3d`` layout), the temporal axis is
    left-padded with ``kernel_t - 1`` replicated frames and never padded on the
    right, so with temporal stride 1 output frame ``t`` depends only on input
    frames ``<= t``. Height and width use symmetric ``k // 2`` zero padding.
    Integer ``kernel_size`` / ``stride`` are expanded to 3-tuples.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: Tuple[int, int, int] = (3, 3, 3),
                 stride: Tuple[int, int, int] = (1, 1, 1)):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        self.kernel_size = kernel_size
        self.stride = stride
        self.temporal_pad = kernel_size[0] - 1
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
                              padding=(0, kernel_size[1] // 2, kernel_size[2] // 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.temporal_pad > 0:
            # F.pad pairs run from the last dim backwards (W, H, T): pad T on the left only.
            x = F.pad(x, (0, 0, 0, 0, self.temporal_pad, 0), mode='replicate')
        return self.conv(x)


class SpatioTemporalBlock(nn.Module):
    """Residual block: per-frame 1x3x3 conv, then a causal 3x1x1 temporal conv.

    Each conv is preceded by GroupNorm + SiLU. The branch output is added to the
    input scaled by a learnable scalar (initialised to 1). Shape is preserved.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.spatial = nn.Sequential(
            nn.GroupNorm(min(8, channels), channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        )
        self.temporal = nn.Sequential(
            nn.GroupNorm(min(8, channels), channels),
            nn.SiLU(),
            CausalConv3d(channels, channels, kernel_size=(3, 1, 1))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.spatial(x)
        h = self.temporal(h)
        return x + self.scale * h


class EncoderStage(nn.ModuleList):
    """Two ``SpatioTemporalBlock``s followed by a stride-2 3x3x3 conv.

    The conv maps ``in_channels -> out_channels`` and halves H and W (rounding up).
    When ``temporal_down`` is true it also halves T.
    """

    def __init__(self, in_channels: int, out_channels: int, temporal_down: bool = True):
        stride = (2, 2, 2) if temporal_down else (1, 2, 2)
        super().__init__([
            SpatioTemporalBlock(in_channels),
            SpatioTemporalBlock(in_channels),
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self[0](x)
        x = self[1](x)
        x = self[2](x)
        return x


class DecoderStage(nn.ModuleList):
    """Mirror of ``EncoderStage``: two blocks then a transposed conv.

    The transposed conv exactly doubles H and W, and also T when ``temporal_up``
    is true (kernel 3, padding 1, output_padding 1 on the upsampled axes).
    """

    def __init__(self, in_channels: int, out_channels: int, temporal_up: bool = True):
        stride = (2, 2, 2) if temporal_up else (1, 2, 2)
        output_padding = (1, 1, 1) if temporal_up else (0, 1, 1)
        super().__init__([
            SpatioTemporalBlock(in_channels),
            SpatioTemporalBlock(in_channels),
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, output_padding=output_padding)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self[0](x)
        x = self[1](x)
        x = self[2](x)
        return x


class RotaryPositionalEmbedding(nn.Module):
    """Cached RoPE ``cos``/``sin`` tables of shape ``(seq_len, dim)``.

    The cache is built for ``max_seq_len`` positions and rebuilt on demand when a
    longer sequence is requested. The buffers are non-persistent (not saved in
    the state dict).
    """

    def __init__(self, dim: int, max_seq_len: int = 1024, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cache', emb.cos(), persistent=False)
        self.register_buffer('sin_cache', emb.sin(), persistent=False)

    def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` for the first ``seq_len`` positions on ``device``."""
        if seq_len > self.cos_cache.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cache[:seq_len].to(device), self.sin_cache[:seq_len].to(device)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves ``(x1, x2)`` and return ``cat(-x2, x1)`` (RoPE rotation)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


# ==============================================================================
# VIDEO VAE
# ==============================================================================

class VideoVAEEncoder(nn.Module):
    """Video VAE encoder producing per-latent-frame tokens and a spatial latent.

    Three ``EncoderStage``s reduce H and W by 8x and T by 4x (the first two stages
    also downsample time). A 1x1x1 conv produces ``mu`` / ``logvar`` with
    ``latent_channels`` channels, and a reparameterised sample ``z_spatial`` is drawn.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        self.config = config
        channels = config.encoder_channels
        self.input_conv = nn.Conv3d(config.video.channels, channels[0], kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.encoders = nn.ModuleList([
            EncoderStage(channels[0], channels[1], temporal_down=True),
            EncoderStage(channels[1], channels[2], temporal_down=True),
            EncoderStage(channels[2], channels[3], temporal_down=False),
        ])
        self.final_blocks = nn.ModuleList([SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])])
        self.to_mu = nn.Conv3d(channels[-1], config.latent_channels, 1)
        self.to_logvar = nn.Conv3d(channels[-1], config.latent_channels, 1)
        # Registered but not used by forward().
        self.to_d_model = nn.Linear(config.latent_channels, config.d_model)
        # Created lazily in forward() once the flattened latent width is known.
        self.adaptive_proj = None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a 5-D video.

        Args:
            x: ``(B, T, C, H, W)`` (detected by ``x.shape[2] == video.channels`` and
                permuted) or ``(B, C, T, H, W)``. NOTE(review): the layout check is
                ambiguous when T equals the channel count.

        Returns:
            ``(z, mu, logvar, z_spatial)``. The first three are ``(B, T_lat, d_model)``
            projections of the flattened spatial latents. ``z_spatial`` is the
            ``(B, latent_channels, T_lat, H_lat, W_lat)`` sample.

        NOTE(review): ``adaptive_proj`` is (re)created here whenever the flattened
        width changes. An optimizer constructed before the first forward pass will
        not contain its parameters.
        """
        if x.dim() == 5 and x.shape[2] == self.config.video.channels:
            x = x.permute(0, 2, 1, 3, 4)
        # Unpacking also rejects inputs that are not 5-D.
        B, C, T, H, W = x.shape
        h = self.input_conv(x)
        for encoder_stage in self.encoders:
            h = encoder_stage(h)
        for block in self.final_blocks:
            h = block(h)
        mu_spatial = self.to_mu(h)
        logvar_spatial = self.to_logvar(h).clamp(-10, 10)
        std = torch.exp(0.5 * logvar_spatial)
        eps = torch.randn_like(std)
        z_spatial = mu_spatial + eps * std
        B, C_lat, T_lat, H_lat, W_lat = z_spatial.shape
        z_flat = z_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
        if self.adaptive_proj is None or self.adaptive_proj.in_features != z_flat.shape[-1]:
            self.adaptive_proj = nn.Linear(z_flat.shape[-1], self.config.d_model).to(z_flat.device)
        z = self.adaptive_proj(z_flat)
        mu_flat = mu_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
        mu = self.adaptive_proj(mu_flat)
        logvar_flat = logvar_spatial.permute(0, 2, 1, 3, 4).reshape(B, T_lat, -1)
        logvar = self.adaptive_proj(logvar_flat)
        return z, mu, logvar, z_spatial


class VideoVAEDecoder(nn.Module):
    """Decode ``(B, latent_channels, T, H, W)`` latents to RGB video in ``[0, 1]``.

    Three ``DecoderStage``s upsample H and W by 8x and T by 4x. The sigmoid RGB head
    is followed by a small causal refinement (tanh-bounded, scaled by a learnable
    factor initialised to 0.05), and the result is clamped to ``[0, 1]``.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        self.config = config
        channels = config.decoder_channels
        self.from_latent = nn.Conv3d(config.latent_channels, channels[0], 1)
        self.init_blocks = nn.ModuleList([SpatioTemporalBlock(channels[0]), SpatioTemporalBlock(channels[0])])
        self.decoders = nn.ModuleList([
            DecoderStage(channels[0], channels[1], temporal_up=True),
            DecoderStage(channels[1], channels[2], temporal_up=True),
            DecoderStage(channels[2], channels[3], temporal_up=False),
        ])
        self.final_blocks = nn.ModuleList([SpatioTemporalBlock(channels[-1]), SpatioTemporalBlock(channels[-1])])
        self.to_rgb = nn.Sequential(
            nn.Conv3d(channels[-1], channels[-1] // 2, (1, 3, 3), padding=(0, 1, 1)),
            nn.SiLU(),
            nn.Conv3d(channels[-1] // 2, config.video.channels, (1, 3, 3), padding=(0, 1, 1)),
            nn.Sigmoid()
        )
        self.temporal_refine = nn.Sequential(
            CausalConv3d(config.video.channels, 32, (3, 3, 3)),
            nn.SiLU(),
            nn.Conv3d(32, config.video.channels, 1),
            nn.Tanh()
        )
        self.refine_scale = nn.Parameter(torch.tensor(0.05))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = self.from_latent(z)
        for block in self.init_blocks:
            h = block(h)
        for decoder_stage in self.decoders:
            h = decoder_stage(h)
        for block in self.final_blocks:
            h = block(h)
        video = self.to_rgb(h)
        refine = self.temporal_refine(video) * self.refine_scale
        video = torch.clamp(video + refine, 0, 1)
        return video


# ==============================================================================
# GQA ATTENTION
# ==============================================================================

class VideoGQA(nn.Module):
    """Grouped-query self-attention over a ``(B, T, d_model)`` sequence.

    ``gqa_num_heads`` query heads share ``gqa_num_kv_groups`` key/value heads
    (each K/V head is repeated ``heads_per_group`` times). Rotary embeddings are
    applied to Q and K, and a causal mask is applied by default.
    ``residual_scale`` is a registered parameter that forward() does not use.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        self.d_model = config.d_model
        self.num_heads = config.gqa_num_heads
        self.num_kv_groups = config.gqa_num_kv_groups
        self.head_dim = config.d_model // self.num_heads
        self.heads_per_group = self.num_heads // self.num_kv_groups
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(config.d_model, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.d_model, self.num_kv_groups * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.d_model, self.num_kv_groups * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len=config.max_frames * 2)
        self.dropout = nn.Dropout(config.dropout)
        self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor, causal: bool = True, use_rope: bool = True) -> torch.Tensor:
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)
        if use_rope:
            cos, sin = self.rope(T, x.device)
            q = (q * cos.unsqueeze(0).unsqueeze(0)) + (rotate_half(q) * sin.unsqueeze(0).unsqueeze(0))
            k = (k * cos.unsqueeze(0).unsqueeze(0)) + (rotate_half(k) * sin.unsqueeze(0).unsqueeze(0))
        # Expand the shared K/V heads so each query head has a matching K/V head.
        k = k.repeat_interleave(self.heads_per_group, dim=1)
        v = v.repeat_interleave(self.heads_per_group, dim=1)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if causal:
            causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
            attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
        # Softmax in float32 for stability, then back to the working dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(attn_output)


# ==============================================================================
# LPOL MEMORY
# ==============================================================================

class VideoLPOL(nn.Module):
    """Domain-weighted learned memory read via cross-attention.

    Each domain in ``config.domain_types`` owns ``memory_slots_per_domain`` learned
    slots. A classifier on the time-pooled input gives per-domain probabilities
    that scale that domain's slots. The sequence then attends (8 fixed heads) over
    all scaled slots, and the retrieved vectors are fused back through a sigmoid
    gate as a residual.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        self.config = config
        self.n_domains = len(config.domain_types)
        self.slots_per_domain = config.memory_slots_per_domain
        self.memories = nn.ParameterDict({
            d: nn.Parameter(torch.randn(self.slots_per_domain, config.d_model) * 0.02)
            for d in config.domain_types
        })
        self.memory_attn = nn.ModuleDict({
            'q_proj': nn.Linear(config.d_model, config.d_model, bias=False),
            'k_proj': nn.Linear(config.d_model, config.d_model, bias=False),
            'v_proj': nn.Linear(config.d_model, config.d_model, bias=False),
            'o_proj': nn.Linear(config.d_model, config.d_model, bias=False),
        })
        self.domain_clf = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, self.n_domains)
        )
        self.fusion = nn.Sequential(nn.Linear(config.d_model * 2, config.d_model), nn.GELU(),
                                    nn.Linear(config.d_model, config.d_model))
        self.gate = nn.Sequential(nn.Linear(config.d_model * 2, config.d_model), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Return ``(x + gate * fused, {'domain_probs', 'top_domain'})`` for ``x`` of shape ``(B, T, D)``."""
        B, T, D = x.shape
        x_pooled = x.mean(dim=1)
        domain_logits = self.domain_clf(x_pooled)
        domain_probs = F.softmax(domain_logits, dim=-1)
        all_memories = []
        for i, domain_name in enumerate(self.config.domain_types):
            mem = self.memories[domain_name]
            weight = domain_probs[:, i:i+1]
            weighted_mem = mem.unsqueeze(0) * weight.unsqueeze(-1)
            all_memories.append(weighted_mem)
        memory_bank = torch.cat(all_memories, dim=1)
        q = self.memory_attn['q_proj'](x)
        k = self.memory_attn['k_proj'](memory_bank)
        v = self.memory_attn['v_proj'](memory_bank)
        # Head count is fixed here (not taken from config); D must be divisible by 8.
        n_heads = 8
        head_dim = D // n_heads
        q = q.view(B, T, n_heads, head_dim).transpose(1, 2)
        k = k.view(B, -1, n_heads, head_dim).transpose(1, 2)
        v = v.view(B, -1, n_heads, head_dim).transpose(1, 2)
        scale = head_dim ** -0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        retrieved = torch.matmul(attn_weights, v)
        retrieved = retrieved.transpose(1, 2).contiguous().view(B, T, D)
        retrieved = self.memory_attn['o_proj'](retrieved)
        concat = torch.cat([x, retrieved], dim=-1)
        gate = self.gate(concat)
        fused = self.fusion(concat)
        output = x + gate * fused
        return output, {'domain_probs': domain_probs, 'top_domain': domain_probs.argmax(dim=-1)}


# ==============================================================================
# EXPERTS AND EARCP
# ==============================================================================

class VideoExpert(nn.Module):
    """Gated feed-forward expert.

    Returns ``(fc2(gelu(fc1(x))) * sigmoid(gate(x)), confidence)``, where
    ``confidence`` is a ``(B, 1)`` sigmoid score computed from the time-pooled input.
    """

    def __init__(self, config: NexusVideoModelConfig, expert_type: str):
        super().__init__()
        self.expert_type = expert_type
        self.confidence = nn.Sequential(nn.Linear(config.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
        self.gate = nn.Linear(config.d_model, config.d_model)
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        conf = self.confidence(x.mean(dim=1))
        gate_val = torch.sigmoid(self.gate(x))
        h = self.dropout(F.gelu(self.fc1(x)))
        h = self.fc2(h)
        out = h * gate_val
        return out, conf


class VideoEARCPLayer(nn.Module):
    """Attention followed by a soft, per-sequence mixture of experts that can grow.

    The router mixes all experts with softmax weights computed from the time-pooled
    hidden state. Coherence is ``0.5 * mean expert confidence + 0.5 * (1 -
    normalised routing entropy)``. After ``growth_patience`` consecutive calls with
    coherence below ``growth_threshold_coherence``, one expert is appended (up to
    ``max_experts``) and the router is widened, keeping the old rows.
    NOTE(review): growth creates new parameters mid-training.
    """

    def __init__(self, config: NexusVideoModelConfig, layer_idx: int, n_experts: int = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if n_experts is None:
            n_experts = len(config.expert_types)
        self.attn_norm = nn.LayerNorm(config.d_model, elementwise_affine=False)
        self.attn_scale = nn.Parameter(torch.ones(1))
        self.temporal_attn = VideoGQA(config)
        self.experts = nn.ModuleList([
            VideoExpert(config, config.expert_types[i] if i < len(config.expert_types) else f"Hybrid_{i}")
            for i in range(n_experts)
        ])
        self.router = nn.Linear(config.d_model, n_experts)
        self.register_buffer('low_coh_count', torch.tensor(0))
        self.coherence_score = 0.5

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, float, bool]:
        """Return ``(x_out, coherence, grew)`` where ``grew`` says whether an expert was added."""
        h = self.attn_norm(x)
        attn_out = self.temporal_attn(h, causal=True)
        x = x + self.attn_scale * attn_out
        router_input = x.mean(dim=1)
        router_logits = self.router(router_input)
        weights = F.softmax(router_logits, dim=-1)
        expert_outputs = []
        confs = []
        for expert in self.experts:
            out, conf = expert(x)
            expert_outputs.append(out)
            confs.append(conf)
        expert_outputs = torch.stack(expert_outputs, dim=1)
        weighted_out = torch.einsum('be,betd->btd', weights, expert_outputs)
        x = x + weighted_out
        confs_tensor = torch.stack(confs, dim=1)
        expert_conf = confs_tensor.mean().item()
        # clamp keeps 0 * log(0) finite (log(0) = -inf would give NaN).
        entropy = -(weights * weights.log().clamp(min=-100)).sum(dim=-1).mean().item()
        max_entropy = math.log(len(self.experts)) if len(self.experts) > 1 else 1.0
        routing_focus = 1 - (entropy / max_entropy)
        coherence = 0.5 * expert_conf + 0.5 * routing_focus
        self.coherence_score = coherence
        grew = False
        if coherence < self.config.growth_threshold_coherence:
            self.low_coh_count += 1
            if self.low_coh_count >= self.config.growth_patience and len(self.experts) < self.config.max_experts:
                new_expert = VideoExpert(self.config, f"Hybrid_{len(self.experts)}").to(x.device)
                self.experts.append(new_expert)
                old_router = self.router
                self.router = nn.Linear(self.config.d_model, len(self.experts)).to(x.device)
                with torch.no_grad():
                    self.router.weight[:old_router.out_features] = old_router.weight
                    self.router.bias[:old_router.out_features] = old_router.bias
                self.low_coh_count.zero_()
                grew = True
        else:
            self.low_coh_count.zero_()
        return x, coherence, grew

    def get_expert_count(self) -> int:
        return len(self.experts)


# ==============================================================================
# NEUROGENESIS
# ==============================================================================

class VideoNeurogenesis(nn.Module):
    """Variable-width ``tanh`` layer with a sigmoid gate and usage tracking.

    forward() updates an EMA of mean absolute activation (``usage``) and a
    per-neuron call counter (``lifetime``). ``maybe_grow`` appends one neuron when
    coherence is at least ``neuron_birth_threshold``. ``resize`` grows with fresh
    neurons, or shrinks by keeping the highest-usage neurons.
    NOTE(review): ``deaths`` is never incremented in this file.
    """

    def __init__(self, input_dim: int, n_neurons: int, config: NexusVideoModelConfig):
        super().__init__()
        self.config = config
        self.input_dim = input_dim
        self.weights = nn.Parameter(torch.randn(n_neurons, input_dim) * 0.02)
        self.bias = nn.Parameter(torch.zeros(n_neurons))
        self.temporal_gate = nn.Linear(input_dim, n_neurons)
        self.register_buffer('n_neurons', torch.tensor(n_neurons))
        self.register_buffer('usage', torch.ones(n_neurons))
        self.register_buffer('lifetime', torch.zeros(n_neurons))
        self.register_buffer('births', torch.tensor(0))
        self.register_buffer('deaths', torch.tensor(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = self.n_neurons.item()
        gate = torch.sigmoid(self.temporal_gate(x))
        gate = gate[..., :n]
        out = torch.tanh(F.linear(x, self.weights[:n], self.bias[:n]))
        out = out * gate
        with torch.no_grad():
            # Reduces over dims 0 and 1; assumes x is (B, T, input_dim) -- TODO confirm for other ranks.
            act = out.abs().mean(dim=(0, 1))
            if act.size(0) == n:
                self.usage[:n] = 0.99 * self.usage[:n] + 0.01 * act
            self.lifetime[:n] += 1
        return out

    def maybe_grow(self, coherence: float) -> int:
        """Append one neuron if enabled, below ``max_neurons`` and ``coherence`` is high enough.

        Returns the number of neurons added (0 or 1).
        """
        if not self.config.neurogenesis_enabled:
            return 0
        n = self.n_neurons.item()
        if n >= self.config.max_neurons or coherence < self.config.neuron_birth_threshold:
            return 0
        device = self.weights.device
        with torch.no_grad():
            new_w = torch.randn(1, self.input_dim, device=device) * 0.02
            new_b = torch.zeros(1, device=device)
            self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
            self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, n + 1).to(device)
            self.temporal_gate.weight.data[:n] = old_gate.weight.data
            self.temporal_gate.weight.data[n:] = torch.randn(1, self.input_dim, device=device) * 0.02
            self.temporal_gate.bias.data[:n] = old_gate.bias.data
            self.temporal_gate.bias.data[n:] = 0
            self.usage = torch.cat([self.usage, torch.ones(1, device=device)])
            self.lifetime = torch.cat([self.lifetime, torch.zeros(1, device=device)])
            self.n_neurons += 1
            self.births += 1
        return 1

    def resize(self, target_neurons: int):
        """Grow or shrink to exactly ``target_neurons`` (shrinking keeps the most-used neurons)."""
        current = self.n_neurons.item()
        if target_neurons == current:
            return
        device = self.weights.device
        if target_neurons > current:
            extra = target_neurons - current
            new_w = torch.randn(extra, self.input_dim, device=device) * 0.02
            new_b = torch.zeros(extra, device=device)
            self.weights = nn.Parameter(torch.cat([self.weights.data, new_w], dim=0))
            self.bias = nn.Parameter(torch.cat([self.bias.data, new_b]))
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
            with torch.no_grad():
                self.temporal_gate.weight[:current] = old_gate.weight
                self.temporal_gate.bias[:current] = old_gate.bias
            self.usage = torch.cat([self.usage, torch.ones(extra, device=device)])
            self.lifetime = torch.cat([self.lifetime, torch.zeros(extra, device=device)])
        else:
            keep_indices = torch.argsort(self.usage, descending=True)[:target_neurons]
            self.weights = nn.Parameter(self.weights.data[keep_indices])
            self.bias = nn.Parameter(self.bias.data[keep_indices])
            old_gate = self.temporal_gate
            self.temporal_gate = nn.Linear(self.input_dim, target_neurons).to(device)
            with torch.no_grad():
                self.temporal_gate.weight[:] = old_gate.weight[keep_indices]
                self.temporal_gate.bias[:] = old_gate.bias[keep_indices]
            self.usage = self.usage[keep_indices]
            self.lifetime = self.lifetime[keep_indices]
        self.n_neurons.fill_(target_neurons)

    def get_stats(self) -> Dict:
        """Return neuron count, birth/death counters, mean usage and max lifetime.

        With zero neurons, ``avg_usage`` and ``max_lifetime`` are reported as 0.0.
        """
        n = self.n_neurons.item()
        if n == 0:
            # mean() of an empty slice is NaN and max() of one raises.
            avg_usage, max_lifetime = 0.0, 0.0
        else:
            avg_usage = self.usage[:n].mean().item()
            max_lifetime = self.lifetime[:n].max().item()
        return {'total_neurons': n,
                'total_births': self.births.item(),
                'total_deaths': self.deaths.item(),
                'avg_usage': avg_usage,
                'max_lifetime': max_lifetime}


# ==============================================================================
# TEMPORAL COHERENCE & FLOW
# ==============================================================================

class TemporalCoherenceModule(nn.Module):
    """Score and smooth a ``(B, T, D)`` latent sequence along time.

    Coherence is ``1 - MSE(predicted diffs, actual diffs)``, clipped to [0, 1] and
    taken via ``.item()``, so it carries no gradient. The output blends the input
    with a depthwise 3-tap temporal smoothing using ``sigmoid(alpha)``. The last
    100 scores are kept in a ring buffer.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        d = config.d_model
        self.diff_predictor = nn.Sequential(nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d))
        self.smooth = nn.Conv1d(d, d, kernel_size=3, padding=1, groups=d)
        self.alpha = nn.Parameter(torch.tensor(0.2))
        self.register_buffer('coherence_history', torch.zeros(100))
        self.register_buffer('history_idx', torch.tensor(0))

    def forward(self, z_seq: torch.Tensor) -> Tuple[torch.Tensor, float]:
        B, T, D = z_seq.shape
        if T > 1:
            diffs = z_seq[:, 1:] - z_seq[:, :-1]
            pairs = torch.cat([z_seq[:, :-1], z_seq[:, 1:]], dim=-1)
            pred_diffs = self.diff_predictor(pairs)
            coherence = 1 - F.mse_loss(pred_diffs, diffs).item()
            coherence = max(0, min(1, coherence))
        else:
            coherence = 1.0
        z_t = z_seq.transpose(1, 2)
        smoothed = self.smooth(z_t).transpose(1, 2)
        alpha = torch.sigmoid(self.alpha)
        output = (1 - alpha) * z_seq + alpha * smoothed
        idx = self.history_idx.item() % 100
        self.coherence_history[idx] = coherence
        self.history_idx += 1
        return output, coherence

    def get_average_coherence(self) -> float:
        """Mean of the recorded scores (at most the last 100); 0.0 if none recorded."""
        valid = min(self.history_idx.item(), 100)
        return self.coherence_history[:valid].mean().item() if valid > 0 else 0.0


class FlowPredictionModule(nn.Module):
    """Predict latent frame ``t+1`` from frame ``t`` and a learned "flow" code.

    Returns a dict with ``flow``, ``warped``, ``motion_magnitude`` (``(B, 1)``) and
    ``flow_loss`` (MSE between warped and actual next frames). Sequences shorter
    than 2 return ``flow=None`` and zero motion and loss.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        d = config.d_model
        self.flow_encoder = nn.Sequential(nn.Linear(d * 2, d), nn.SiLU(), nn.Linear(d, d // 2), nn.SiLU(),
                                          nn.Linear(d // 2, d))
        self.warp_net = nn.Sequential(nn.Linear(d * 2, d), nn.Tanh())
        self.motion_magnitude = nn.Sequential(nn.Linear(d, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, z_seq: torch.Tensor) -> Dict:
        B, T, D = z_seq.shape
        if T < 2:
            return {'flow': None, 'warped': z_seq,
                    'motion_magnitude': torch.zeros(B, 1, device=z_seq.device),
                    'flow_loss': torch.tensor(0.0, device=z_seq.device)}
        z_t = z_seq[:, :-1]
        z_t1 = z_seq[:, 1:]
        pairs = torch.cat([z_t, z_t1], dim=-1)
        flow = self.flow_encoder(pairs)
        warp_input = torch.cat([z_t, flow], dim=-1)
        warped = self.warp_net(warp_input)
        motion = self.motion_magnitude(flow.mean(dim=1))
        flow_loss = F.mse_loss(warped, z_t1)
        return {'flow': flow, 'warped': warped, 'motion_magnitude': motion, 'flow_loss': flow_loss}


class VideoEnergySystem(nn.Module):
    """Scalar energy budget in [0, 1] that operations draw from.

    ``consume`` deducts a per-operation cost only if enough energy remains.
    ``regenerate`` refills by ``energy_regeneration``, capped at 1.0.
    ``reset`` refills energy but keeps the ``consumed`` total.
    """

    def __init__(self, config: NexusVideoModelConfig):
        super().__init__()
        self.config = config
        self.register_buffer('energy', torch.tensor(1.0))
        self.register_buffer('consumed', torch.tensor(0.0))
        self.costs = {'encode': config.energy_cost_encode, 'decode': config.energy_cost_decode,
                      'predict': config.energy_cost_predict, 'process': 0.01, 'memory': 0.005,
                      'attention': 0.008}

    def consume(self, operation: str, amount: float = None) -> bool:
        """Try to spend energy on ``operation``.

        Args:
            operation: key into ``self.costs``; unknown keys cost 0.01.
            amount: explicit cost overriding the table. ``0.0`` is honoured as a
                free operation; only ``None`` falls back to the table.

        Returns:
            True if the cost was deducted, False if energy was insufficient (state unchanged).
        """
        cost = amount if amount is not None else self.costs.get(operation, 0.01)
        if self.energy.item() >= cost:
            self.energy -= cost
            self.consumed += cost
            return True
        return False

    def regenerate(self):
        regen = min(self.config.energy_regeneration, 1.0 - self.energy.item())
        self.energy += regen

    def reset(self):
        self.energy.fill_(1.0)

    def get_stats(self) -> Dict:
        return {'energy': self.energy.item(), 'consumed': self.consumed.item()}


# ==============================================================================
# MAIN MODEL
# ==============================================================================

class NexusVideoModel(nn.Module):
    """Video VAE with a temporal processing stack over the latent sequence.

    Pipeline in :meth:`forward`:
      encode -> process_temporal (LPOL memory -> EARCP layers -> neurogenesis ->
      temporal coherence -> flow) -> decode.
    Reconstruction decodes the encoder's ``z_spatial`` directly, so the processed
    sequence influences the loss only through ``flow_loss``.

    Args:
        config: model config (defaults to ``NexusVideoModelConfig()``).
        expert_counts: optional per-layer expert counts (used when restoring grown models).
    """

    def __init__(self, config: NexusVideoModelConfig = None, expert_counts: List[int] = None):
        super().__init__()
        self.config = config or NexusVideoModelConfig()
        if expert_counts is None:
            expert_counts = [len(self.config.expert_types)] * self.config.n_layers
        self.encoder = VideoVAEEncoder(self.config)
        self.decoder = VideoVAEDecoder(self.config)
        self.lpol = VideoLPOL(self.config) if self.config.use_lpol else None
        self.layers = nn.ModuleList([
            VideoEARCPLayer(self.config, i, n_experts=expert_counts[i]) for i in range(self.config.n_layers)
        ])
        self.neurogenesis = VideoNeurogenesis(self.config.d_model, 64, self.config)
        self.neuro_proj = nn.Linear(self.config.max_neurons, self.config.d_model)
        self.temporal_coherence = TemporalCoherenceModule(self.config)
        self.flow_module = FlowPredictionModule(self.config) if self.config.flow_prediction else None
        self.energy = VideoEnergySystem(self.config)
        # Not used by any method in this class.
        self.frame_predictor = nn.Sequential(
            nn.Linear(self.config.d_model, self.config.d_model),
            nn.SiLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.d_model, self.config.latent_channels * 8 * 8)
        )

    def encode(self, video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Charge the 'encode' cost and run the VAE encoder (see ``VideoVAEEncoder.forward``)."""
        self.energy.consume('encode')
        return self.encoder(video)

    def decode(self, z_spatial: torch.Tensor) -> torch.Tensor:
        """Charge the 'decode' cost and run the VAE decoder."""
        self.energy.consume('decode')
        return self.decoder(z_spatial)

    def process_temporal(self, z: torch.Tensor) -> Dict:
        """Run the temporal stack on a ``(B, T, d_model)`` latent sequence.

        Returns a dict with the processed ``z``, mean layer ``coherence``,
        ``temporal_coherence``, growth counters, LPOL / flow info and energy stats.
        """
        self.energy.consume('process')
        lpol_info = {}
        if self.lpol is not None:
            z, lpol_info = self.lpol(z)
        coherences = []
        total_growth = 0
        for layer in self.layers:
            z, coh, grew = layer(z)
            coherences.append(coh)
            if grew:
                total_growth += 1
        neuro_out = self.neurogenesis(z)
        # Match whatever width neuro_proj actually has: it is max_neurons by default, but
        # from_pretrained may install a narrower projection for legacy checkpoints.
        target_width = self.neuro_proj.in_features
        current_neurons = neuro_out.shape[-1]
        if current_neurons < target_width:
            padding = torch.zeros(*neuro_out.shape[:-1], target_width - current_neurons,
                                  device=neuro_out.device, dtype=neuro_out.dtype)
            neuro_out_padded = torch.cat([neuro_out, padding], dim=-1)
        else:
            neuro_out_padded = neuro_out[..., :target_width]
        neuro_proj = self.neuro_proj(neuro_out_padded)
        z = z + 0.1 * neuro_proj.mean(dim=-1, keepdim=True).expand_as(z)
        avg_coherence = sum(coherences) / len(coherences) if coherences else 0.0
        neuro_growth = self.neurogenesis.maybe_grow(avg_coherence)
        z, temp_coherence = self.temporal_coherence(z)
        flow_info = {}
        if self.flow_module is not None:
            flow_info = self.flow_module(z)
        self.energy.regenerate()
        return {'z': z, 'coherence': avg_coherence, 'temporal_coherence': temp_coherence,
                'expert_growth': total_growth, 'neuro_growth': neuro_growth, 'lpol_info': lpol_info,
                'flow_info': flow_info, 'energy': self.energy.get_stats()}

    def forward(self, video: torch.Tensor) -> Dict:
        """Encode, process and reconstruct ``video``, returning losses and diagnostics.

        ``video`` is ``(B, T, C, H, W)`` or ``(B, C, T, H, W)``. ``recon`` is returned
        channel-first, interpolated to the input size if needed. ``coherence_loss``
        and ``flow_loss`` are Python floats; the ``coherence_loss`` term in
        ``loss`` has no gradient.
        """
        z, mu, logvar, z_spatial = self.encode(video)
        proc_out = self.process_temporal(z)
        recon = self.decode(z_spatial)
        if video.shape[2] == self.config.video.channels:
            video_compare = video.permute(0, 2, 1, 3, 4)
        else:
            video_compare = video
        if recon.shape != video_compare.shape:
            recon = F.interpolate(recon, size=video_compare.shape[2:], mode='trilinear', align_corners=False)
        recon_loss = F.mse_loss(recon, video_compare)
        kl_loss = 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1)
        coherence_loss = float(1.0 - proc_out['temporal_coherence'])
        flow_loss_tensor = torch.tensor(0.0, device=recon.device)
        if 'flow_info' in proc_out and proc_out['flow_info'].get('flow_loss') is not None:
            flow_loss_tensor = proc_out['flow_info']['flow_loss']
        total_loss = (recon_loss + self.config.kl_weight * kl_loss
                      + self.config.temporal_coherence_weight * coherence_loss
                      + 0.01 * flow_loss_tensor)
        flow_loss = flow_loss_tensor.item() if hasattr(flow_loss_tensor, 'item') else float(flow_loss_tensor)
        return {'loss': total_loss, 'recon_loss': recon_loss, 'kl_loss': kl_loss,
                'coherence_loss': coherence_loss, 'flow_loss': flow_loss, 'recon': recon, 'z': z,
                'z_spatial': z_spatial, 'coherence': proc_out['coherence'],
                'temporal_coherence': proc_out['temporal_coherence'],
                'neurogenesis': proc_out['neuro_growth'], 'expert_growth': proc_out['expert_growth'],
                'energy': proc_out['energy']}

    def generate(self, n_frames: int = 16, z_init: torch.Tensor = None, temperature: float = 1.0,
                 batch_size: int = 1) -> torch.Tensor:
        """Decode ``z_init`` (or Gaussian latents scaled by ``temperature``) into a clip.

        The sampled latent shape assumes the encoder's 8x spatial / 4x temporal
        reduction. The output is ``(B, C, n_frames, H, W)`` in [0, 1]. Side effect:
        the model is switched to eval mode and left there.
        """
        self.eval()
        device = next(self.parameters()).device
        if z_init is None:
            B = batch_size
            T = max(1, n_frames // 4)
            H = self.config.video.height // 8
            W = self.config.video.width // 8
            z_init = torch.randn(B, self.config.latent_channels, T, H, W, device=device) * temperature
        with torch.no_grad():
            video = self.decode(z_init)
            if video.shape[2] != n_frames:
                video = F.interpolate(video, size=(n_frames, video.shape[3], video.shape[4]),
                                      mode='trilinear', align_corners=False)
        return video.clamp(0, 1)

    def count_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def diagnostics(self) -> Dict:
        total_experts = sum(layer.get_expert_count() for layer in self.layers)
        return {'model_version': self.config.version, 'total_params': self.count_params(),
                'total_experts': total_experts,
                'expert_counts': [layer.get_expert_count() for layer in self.layers],
                'neurogenesis': self.neurogenesis.get_stats(), 'energy': self.energy.get_stats()}

    @staticmethod
    def _load_checkpoint(model_path: str) -> Dict:
        """Load a ``pytorch_model.bin`` state dict onto the CPU.

        ``weights_only=True`` restricts unpickling to tensors and plain containers,
        which matters because the file may come from the Hub. Older torch releases
        without that argument fall back to the unrestricted loader.
        """
        try:
            return torch.load(model_path, map_location='cpu', weights_only=True)
        except TypeError:
            return torch.load(model_path, map_location='cpu')

    @classmethod
    def from_pretrained(cls, pretrained_path: str, device: str = None, **kwargs) -> "NexusVideoModel":
        """Load a model from a local directory or a Hub repo id.

        Reads ``config.json`` (``kwargs`` override its fields) and
        ``pytorch_model.bin``. Dynamic sizes (experts, neurons, lazily created
        projections) come from ``_dynamic_state``; the neuron count and projection
        widths are overridden by the checkpoint tensors' shapes when those are
        present. The model is moved to ``device`` (CUDA if available by default).
        """
        if os.path.isdir(pretrained_path):
            load_dir = pretrained_path
        else:
            # Imported lazily so local loading does not require huggingface_hub.
            from huggingface_hub import snapshot_download
            load_dir = snapshot_download(repo_id=pretrained_path)

        config_path = os.path.join(load_dir, "config.json")
        dynamic_state = None
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            dynamic_state = config_dict.pop('_dynamic_state', None)
            config_dict.pop('_class_name', None)
            config_dict.pop('_version', None)
            config_dict.pop('_architecture', None)
            config_dict.update(kwargs)
            config = NexusVideoModelConfig.from_dict(config_dict)
        else:
            config = NexusVideoModelConfig(**kwargs)

        expert_counts = None
        neuron_count = 64
        neuro_proj_in = None  # None keeps the default max_neurons-wide projection
        adaptive_proj_in = None
        adaptive_proj_out = None
        if dynamic_state is not None:
            expert_counts = dynamic_state.get('expert_counts')
            neuron_count = dynamic_state.get('neuron_count', 64)
            neuro_proj_in = dynamic_state.get('neuro_proj_in', neuron_count)
            adaptive_proj_in = dynamic_state.get('adaptive_proj_in')
            adaptive_proj_out = dynamic_state.get('adaptive_proj_out')

        model_path = os.path.join(load_dir, "pytorch_model.bin")
        state_dict = cls._load_checkpoint(model_path) if os.path.exists(model_path) else None
        if state_dict is not None:
            # Checkpoint shapes are authoritative. load_state_dict raises on size
            # mismatches even with strict=False, and a lazily created adaptive_proj
            # would otherwise be dropped as an "unexpected" key.
            if 'neurogenesis.weights' in state_dict:
                neuron_count = int(state_dict['neurogenesis.weights'].shape[0])
            if 'neuro_proj.weight' in state_dict:
                neuro_proj_in = int(state_dict['neuro_proj.weight'].shape[1])
            if 'encoder.adaptive_proj.weight' in state_dict:
                adaptive_proj_out, adaptive_proj_in = (int(s) for s in state_dict['encoder.adaptive_proj.weight'].shape)

        model = cls(config, expert_counts=expert_counts)
        if neuron_count != model.neurogenesis.n_neurons.item():
            model.neurogenesis.resize(neuron_count)
        # For backward compatibility with old checkpoints that had smaller neuro_proj
        if neuro_proj_in and neuro_proj_in != model.neuro_proj.in_features:
            model.neuro_proj = nn.Linear(neuro_proj_in, config.d_model)
        if adaptive_proj_in and adaptive_proj_out:
            model.encoder.adaptive_proj = nn.Linear(adaptive_proj_in, adaptive_proj_out)

        if state_dict is not None:
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            if missing:
                logger.warning(f"Missing keys: {len(missing)}")
            if unexpected:
                logger.warning(f"Unexpected keys: {len(unexpected)}")
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return model.to(device)