""" ECG Lead Generator — Model Architecture CLIP-Conditioned 1D U-Net: 7 known leads → 5 predicted leads (V2-V6) """ import torch import torch.nn as nn import torch.nn.functional as F class FiLM(nn.Module): """Feature-wise Linear Modulation for CLIP conditioning (scale + shift).""" def __init__(self, cond_d: int, ch: int): super().__init__() self.scale = nn.Linear(cond_d, ch) self.shift = nn.Linear(cond_d, ch) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: return x * (1 + self.scale(c).unsqueeze(-1)) + self.shift(c).unsqueeze(-1) class ResBlk(nn.Module): """Residual conv block with GroupNorm + GELU + optional FiLM conditioning.""" def __init__(self, ci: int, co: int, cd: int = None, drop: float = 0.1): super().__init__() g = lambda ch: min(8, ch) self.body = nn.Sequential( nn.GroupNorm(g(ci), ci), nn.GELU(), nn.Conv1d(ci, co, 3, padding=1), nn.Dropout(drop), nn.GroupNorm(g(co), co), nn.GELU(), nn.Conv1d(co, co, 3, padding=1), ) self.skip = nn.Conv1d(ci, co, 1) if ci != co else nn.Identity() self.film = FiLM(cd, co) if cd else None def forward(self, x, c=None): h = self.body(x) if self.film and c is not None: h = self.film(h, c) return h + self.skip(x) class Down(nn.Module): def __init__(self, ch): super().__init__() self.p = nn.Conv1d(ch, ch, 4, 2, 1) def forward(self, x): return self.p(x) class Up(nn.Module): def __init__(self, ci, co): super().__init__() self.u = nn.ConvTranspose1d(ci, co, 4, 2, 1) def forward(self, x, skip): x = self.u(x) d = skip.shape[-1] - x.shape[-1] if d > 0: x = F.pad(x, [0, d]) return torch.cat([x[:, :, :skip.shape[-1]], skip], dim=1) class LeadGenerator(nn.Module): """ CLIP-conditioned 1D U-Net. Input : [B, 7, L] — 7 known ECG leads (I, II, III, aVR, aVL, aVF, V1) Cond : [B, D] — CLIP visual embedding (FiLM-injected at every scale) Output: [B, 5, L] — predicted leads V2, V3, V4, V5, V6 """ def __init__(self, ni=7, no=5, ch=64, cd=1024, drop=0.1): super().__init__() self.cproj = nn.Sequential( nn.Linear(cd, ch * 4), nn.GELU(), nn.Linear(ch * 4, ch * 4) ) C = ch * 4 self.e1, self.d1 = ResBlk(ni, ch, C, drop), Down(ch) self.e2, self.d2 = ResBlk(ch, ch*2, C, drop), Down(ch*2) self.e3, self.d3 = ResBlk(ch*2, ch*4, C, drop), Down(ch*4) self.e4, self.d4 = ResBlk(ch*4, ch*8, C, drop), Down(ch*8) self.m1 = ResBlk(ch*8, ch*8, C, drop) self.m2 = ResBlk(ch*8, ch*8, C, drop) self.u4, self.r4 = Up(ch*8, ch*8), ResBlk(ch*16, ch*8, C, drop) self.u3, self.r3 = Up(ch*8, ch*4), ResBlk(ch*8, ch*4, C, drop) self.u2, self.r2 = Up(ch*4, ch*2), ResBlk(ch*4, ch*2, C, drop) self.u1, self.r1 = Up(ch*2, ch), ResBlk(ch*2, ch, C, drop) self.out = nn.Sequential( nn.GroupNorm(min(8, ch), ch), nn.GELU(), nn.Conv1d(ch, no, 1) ) def forward(self, x, clip_emb): c = self.cproj(clip_emb) s1 = self.e1(x, c); x = self.d1(s1) s2 = self.e2(x, c); x = self.d2(s2) s3 = self.e3(x, c); x = self.d3(s3) s4 = self.e4(x, c); x = self.d4(s4) x = self.m2(self.m1(x, c), c) x = self.r4(self.u4(x, s4), c) x = self.r3(self.u3(x, s3), c) x = self.r2(self.u2(x, s2), c) x = self.r1(self.u1(x, s1), c) return self.out(x) def load_from_hub(repo_id: str = "your-username/ecg-lead-generator") -> LeadGenerator: """Load LeadGenerator weights from Hugging Face Hub.""" from huggingface_hub import hf_hub_download import json config_path = hf_hub_download(repo_id, "config.json") with open(config_path) as f: cfg = json.load(f) model = LeadGenerator( ni=cfg["n_in"], no=cfg["n_out"], ch=cfg["base_ch"], cd=cfg["clip_dim"], ) try: from safetensors.torch import load_file w_path = hf_hub_download(repo_id, "model.safetensors") state = load_file(w_path) except Exception: w_path = hf_hub_download(repo_id, "lead_generator_weights.pt") ckpt = torch.load(w_path, map_location="cpu") state = ckpt["model_state"] model.load_state_dict(state) model.eval() return model