""" inpainter_lama.py — big-lama PyTorch inpainter Architecture built to match checkpoint indices exactly. """ import os, sys, warnings import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from config import CHECKPOINT_DIR, MASK_DILATION_PX, DEVICE LAMA_CKPT = os.path.join(CHECKPOINT_DIR, "big-lama", "models", "best.ckpt") # ── FFC primitives ──────────────────────────────────────────────────────────── class FourierUnit(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv_layer = nn.Conv2d(in_ch * 2, out_ch * 2, 1, bias=False) self.bn = nn.BatchNorm2d(out_ch * 2) self.relu = nn.ReLU(inplace=True) def forward(self, x): b, c, h, w = x.shape f = torch.fft.rfftn(x, dim=(-2,-1), norm="ortho") f = torch.stack([f.real, f.imag], -1).permute(0,1,4,2,3).contiguous() f = f.view(b, -1, f.shape[-2], f.shape[-1]) f = self.relu(self.bn(self.conv_layer(f))) f = f.view(b, -1, 2, f.shape[-2], f.shape[-1]).permute(0,1,3,4,2).contiguous() f = torch.view_as_complex(f) return torch.fft.irfftn(f, s=(h,w), dim=(-2,-1), norm="ortho") class SpectralTransform(nn.Module): def __init__(self, in_ch, out_ch, stride=1): super().__init__() self.downsample = nn.AvgPool2d(2) if stride == 2 else nn.Identity() self.conv1 = nn.Sequential( nn.Conv2d(in_ch, out_ch//2, 1, bias=False), nn.BatchNorm2d(out_ch//2), nn.ReLU(inplace=True)) self.fu = FourierUnit(out_ch//2, out_ch//2) self.conv2 = nn.Conv2d(out_ch//2, out_ch, 1, bias=False) def forward(self, x): return self.conv2(self.fu(self.conv1(self.downsample(x)))) class FFC(nn.Module): def __init__(self, in_ch, out_ch, ksize, ratio_gin, ratio_gout, stride=1, pad=0): super().__init__() in_cg = int(in_ch * ratio_gin); in_cl = in_ch - in_cg out_cg = int(out_ch * ratio_gout); out_cl = out_ch - out_cg self.in_cg = in_cg self.convl2l = nn.Conv2d(in_cl, out_cl, ksize, stride, pad, bias=False) if 
in_cl>0 and out_cl>0 else None self.convl2g = nn.Conv2d(in_cl, out_cg, ksize, stride, pad, bias=False) if in_cl>0 and out_cg>0 else None self.convg2l = nn.Conv2d(in_cg, out_cl, ksize, stride, pad, bias=False) if in_cg>0 and out_cl>0 else None self.convg2g = SpectralTransform(in_cg, out_cg, stride) if in_cg>0 and out_cg>0 else None def forward(self, x): xl, xg = x if isinstance(x, tuple) else (x, None) yl_parts = [f(t) for f, t in [(self.convl2l, xl), (self.convg2l, xg)] if f is not None and t is not None] yg_parts = [f(t) for f, t in [(self.convl2g, xl), (self.convg2g, xg)] if f is not None and t is not None] yl = yl_parts[0] + yl_parts[1] if len(yl_parts) == 2 else (yl_parts[0] if yl_parts else torch.zeros_like(xl)) yg = yg_parts[0] + yg_parts[1] if len(yg_parts) == 2 else (yg_parts[0] if yg_parts else None) return yl, yg class FFCBNAct(nn.Module): def __init__(self, in_ch, out_ch, ksize, ratio_gin, ratio_gout, stride=1, pad=0): super().__init__() out_cl = int(out_ch*(1-ratio_gout)); out_cg = out_ch - out_cl self.ffc = FFC(in_ch, out_ch, ksize, ratio_gin, ratio_gout, stride, pad) self.bn_l = nn.BatchNorm2d(out_cl) if out_cl>0 else None self.bn_g = nn.BatchNorm2d(out_cg) if out_cg>0 else None def forward(self, x): xl, xg = self.ffc(x) if self.bn_l: xl = F.relu(self.bn_l(xl), inplace=True) if self.bn_g and xg is not None: xg = F.relu(self.bn_g(xg), inplace=True) return xl, xg class FFCResBlock(nn.Module): def __init__(self, dim, ratio_gin, ratio_gout): super().__init__() self.conv1 = FFCBNAct(dim, dim, 3, ratio_gin, ratio_gout, pad=1) self.conv2 = FFCBNAct(dim, dim, 3, ratio_gout, ratio_gout, pad=1) def forward(self, x): xl, xg = x if isinstance(x, tuple) else (x, None) rl, rg = xl, xg xl, xg = self.conv1((xl, xg)) xl, xg = self.conv2((xl, xg)) xl = xl + rl if xg is not None and rg is not None: xg = xg + rg return xl, xg # ── Generator — indices match checkpoint exactly ────────────────────────────── # Index map from checkpoint: # 0 = ReflectionPad2d(3) [no params] 
#   1     = FFCBNAct(4→64,    k7, 0→0)
#   2     = FFCBNAct(64→128,  k3, 0→0,    stride=2)
#   3     = FFCBNAct(128→256, k3, 0→0,    stride=2)
#   4     = FFCBNAct(256→512, k3, 0→0.75, stride=2)   128 local + 384 global
#   5-22  = FFCResBlock(512, 0.75→0.75) × 18
#   23    = concat(local, global) → 512 ch              [no params]
#   24    = ConvTranspose2d(512→256)
#   25    = BatchNorm2d(256)
#   26    = ReLU                                        [no params]
#   27    = ConvTranspose2d(256→128)
#   28    = BatchNorm2d(128)
#   29    = ReLU                                        [no params]
#   30    = ConvTranspose2d(128→64)
#   31    = BatchNorm2d(64)
#   32    = ReLU                                        [no params]
#   33    = ReflectionPad2d(3)                          [no params]
#   34    = Conv2d(64→3, k7)
#   35    = Sigmoid                                     [no params]
class LaMaGenerator(nn.Module):
    """big-lama FFC-ResNet generator, laid out to load the checkpoint directly.

    Layers with parameters live in ``self.model`` under the string index they
    carry in the checkpoint (``model.<idx>.*``). Parameter-free steps (padding,
    activations, the concat at idx 23) are applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        ngf = 64
        self.model = nn.ModuleDict({
            "1": FFCBNAct(4, ngf, 7, 0, 0, pad=0),
            "2": FFCBNAct(ngf, ngf * 2, 3, 0, 0, stride=2, pad=1),
            "3": FFCBNAct(ngf * 2, ngf * 4, 3, 0, 0, stride=2, pad=1),
            "4": FFCBNAct(ngf * 4, ngf * 8, 3, 0, 0.75, stride=2, pad=1),
            **{str(i): FFCResBlock(ngf * 8, 0.75, 0.75) for i in range(5, 23)},
            "24": nn.ConvTranspose2d(ngf * 8, ngf * 4, 3, stride=2, padding=1, output_padding=1),
            "25": nn.BatchNorm2d(ngf * 4),
            "27": nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, stride=2, padding=1, output_padding=1),
            "28": nn.BatchNorm2d(ngf * 2),
            "30": nn.ConvTranspose2d(ngf * 2, ngf, 3, stride=2, padding=1, output_padding=1),
            "31": nn.BatchNorm2d(ngf),
            "34": nn.Conv2d(ngf, 3, 7, padding=0),
        })

    def forward(self, x):
        """Map a 4-channel input (masked RGB + mask) to an RGB image in [0, 1].

        The encoder halves H and W three times and the decoder doubles them
        three times, so the output matches the input size only when H and W
        are multiples of 8 (the caller pads accordingly).
        """
        m = self.model
        x = F.pad(x, (3, 3, 3, 3), mode="reflect")            # idx 0
        for idx in range(1, 23):                              # idx 1-4 encoder, 5-22 res blocks
            x = m[str(idx)](x)                                # tensor -> (local, global) tuple
        xl, xg = x
        # idx 23: merge local + global into 512 channels. No activation here:
        # both halves are already non-negative (every FFCBNAct ends in ReLU and
        # the res blocks only add non-negative skips).
        x = torch.cat([xl, xg], dim=1)
        for conv_idx, bn_idx in (("24", "25"), ("27", "28"), ("30", "31")):
            x = F.relu(m[bn_idx](m[conv_idx](x)), inplace=True)  # idx 26 / 29 / 32
        x = F.pad(x, (3, 3, 3, 3), mode="reflect")            # idx 33
        return torch.sigmoid(m["34"](x))                      # idx 34 + 35


# ── Inpainter ─────────────────────────────────────────────────────────────────
class LamaInpainter:
    """big-lama inpainter with PyTorch -> ONNX -> passthrough fallbacks.

    Construction tries the PyTorch checkpoint at ``LAMA_CKPT`` first and, if
    that does not load, the ONNX export (downloaded from the Hugging Face Hub
    when missing). ``inpaint`` uses whichever backend is available.
    """

    def __init__(self):
        self._model = None    # LaMaGenerator once the PyTorch checkpoint loads
        self._session = None  # onnxruntime InferenceSession for the ONNX fallback
        self._load_pytorch()
        if self._model is None:
            self._load_onnx()

    def _load_pytorch(self):
        """Load the big-lama generator weights; leaves ``self._model`` None on failure."""
        if not os.path.exists(LAMA_CKPT):
            print(f" big-lama checkpoint not found at {LAMA_CKPT}")
            return
        try:
            print(" Loading big-lama (PyTorch) ...")
            # SECURITY: weights_only=False unpickles arbitrary Python objects;
            # only load checkpoints from a trusted source.
            ckpt = torch.load(LAMA_CKPT, map_location="cpu", weights_only=False)
            state = ckpt.get("state_dict", ckpt)
            prefix = "generator."
            gen_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
            model = LaMaGenerator()
            missing, unexpected = model.load_state_dict(gen_state, strict=False)
            # BatchNorm step counters are irrelevant for inference; any other
            # missing key means randomly initialised weights and garbage output,
            # so fail here and let the ONNX fallback take over.
            missing = [k for k in missing if not k.endswith("num_batches_tracked")]
            if missing:
                raise RuntimeError(
                    f"{len(missing)} generator weights missing from checkpoint, e.g. {missing[:3]}")
            if unexpected:
                print(f" Ignoring {len(unexpected)} unexpected keys, e.g. {unexpected[:3]}")
            model.eval().to(DEVICE)
            self._model = model
            print(" big-lama loaded successfully.")
        except Exception as e:
            warnings.warn(f" big-lama PyTorch load failed: {e}")

    def _load_onnx(self):
        """Load (downloading if needed) the LaMa ONNX export into ``self._session``."""
        onnx_path = os.path.join(CHECKPOINT_DIR, "lama_fp32.onnx")
        try:
            import onnxruntime as ort
            from huggingface_hub import hf_hub_download
            if not os.path.exists(onnx_path):
                print(" Downloading LaMa ONNX (~125 MB)...")
                os.makedirs(CHECKPOINT_DIR, exist_ok=True)
                # Use the path the hub reports rather than assuming its layout.
                onnx_path = hf_hub_download(repo_id="Carve/LaMa-ONNX", filename="lama_fp32.onnx",
                                            local_dir=CHECKPOINT_DIR)
            self._session = ort.InferenceSession(
                onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
            print(" LaMa ONNX loaded as fallback.")
        except Exception as e:
            warnings.warn(f" LaMa ONNX load failed: {e}. Will use OpenCV.")

    # ── public ────────────────────────────────────────────────────────────────
    def inpaint(self, image_pil: Image.Image, mask: np.ndarray) -> Image.Image:
        """Fill the pixels where ``mask > 0`` and return a new RGB image.

        Backends are tried in order: PyTorch big-lama (retried once on CPU
        after an out-of-memory error), then ONNX. If every backend fails, the
        unmodified image is returned rather than a blurry classical fill.

        Args:
            image_pil: Input image; converted to RGB.
            mask: Non-zero marks pixels to fill. Expected to match the image's
                height and width; a trailing channel axis is collapsed with ``any``.
        """
        img_np = np.array(image_pil.convert("RGB"))
        mask = self._as_2d_mask(mask)

        if self._model is not None:
            result = self._inpaint_with_model(img_np, mask)
            if result is not None:
                return result

        # ONNX is the next best fallback for quality.
        if self._session is None:
            self._load_onnx()
        if self._session is not None:
            try:
                return self._onnx_inpaint(img_np, mask)
            except Exception as e:
                print(f" [ERROR] ONNX inpaint failed: {e}")

        # Returning the original avoids the smeared look of a classical fill.
        print(" [CRITICAL] All AI inpainting failed. Returning original to avoid blurry corruption.")
        return Image.fromarray(img_np)

    # ── helpers ───────────────────────────────────────────────────────────────
    @staticmethod
    def _as_2d_mask(mask):
        """Return *mask* as a 2-D array, collapsing a trailing channel axis with ``any``."""
        mask = np.asarray(mask)
        if mask.ndim == 3:
            mask = (mask > 0).any(axis=-1)
        return mask

    def _inpaint_with_model(self, img_np, mask):
        """Run the PyTorch backend, retrying on CPU after OOM. Returns None on failure."""
        import gc

        is_oom = False
        try:
            return self._pytorch_inpaint(img_np, mask)
        except Exception as e:
            is_oom = "out of memory" in str(e).lower()
            if not is_oom:
                print(f" [ERROR] PyTorch inpaint failed with non-OOM error: {e}")
        # Free memory only after leaving the except block: until then the
        # exception's traceback keeps the failed attempt's frames (and tensors) alive.
        gc.collect()
        torch.cuda.empty_cache()
        if not is_oom:
            return None

        print(" [WARN] CUDA VRAM exceeded! Rerunning LaMa on CPU to guarantee sharp, clear output...")
        old_device = next(self._model.parameters()).device
        self._model.to("cpu")
        try:
            return self._pytorch_inpaint(img_np, mask, device_override="cpu")
        except Exception as e:
            print(f" [ERROR] CPU PyTorch inpaint failed as well: {e}")
            return None
        finally:
            self._restore_device(old_device)

    def _restore_device(self, device):
        """Move the model back to *device*; keep it wholly on CPU if that fails."""
        try:
            self._model.to(device)
        except Exception as e:
            # A partial move would leave parameters split across devices.
            print(f" [WARN] Could not move big-lama back to {device}, keeping it on CPU: {e}")
            self._model.to("cpu")

    def _pytorch_inpaint(self, img_np, mask, device_override=None):
        """Inpaint a crop around the mask with the PyTorch generator.

        Only the mask's bounding box plus a margin (at least 64 px, or 30% of
        the box's longer side) is processed, and only masked pixels are written
        back. Runs on *device_override* if given, else where the model lives.
        """
        exec_device = device_override or next(self._model.parameters()).device
        h, w = img_np.shape[:2]
        ys, xs = np.where(mask > 0)
        if len(xs) == 0:
            return Image.fromarray(img_np)

        x1, x2 = int(xs.min()), int(xs.max())
        y1, y2 = int(ys.min()), int(ys.max())
        margin = max(64, int(max(x2 - x1, y2 - y1) * 0.3))
        # max() is inclusive; +1 turns it into an exclusive slice end.
        x1, y1 = max(0, x1 - margin), max(0, y1 - margin)
        x2, y2 = min(w, x2 + 1 + margin), min(h, y2 + 1 + margin)

        crop = img_np[y1:y2, x1:x2]
        crop_mask = mask[y1:y2, x1:x2] > 0
        ch, cw = crop.shape[:2]

        # The generator needs H, W divisible by 8. Mirror padding (as in the
        # reference LaMa predictor) avoids feeding it an artificial black border.
        pad_h, pad_w = -ch % 8, -cw % 8
        img_p = np.pad(crop, ((0, pad_h), (0, pad_w), (0, 0)), mode="symmetric")
        mask_p = np.pad(crop_mask, ((0, pad_h), (0, pad_w)), mode="symmetric")

        img_t = torch.from_numpy(img_p.astype(np.float32) / 255.).permute(2, 0, 1).unsqueeze(0).to(exec_device)
        mask_t = torch.from_numpy(mask_p.astype(np.float32))[None, None].to(exec_device)
        inp = torch.cat([img_t * (1 - mask_t), mask_t], dim=1)
        with torch.no_grad():
            out = self._model(inp)
        res = (out[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)[:ch, :cw]

        out_np = img_np.copy()
        out_np[y1:y2, x1:x2][crop_mask] = res[crop_mask]
        return Image.fromarray(out_np)

    def _onnx_inpaint(self, img_np, mask):
        """Inpaint a crop around the mask via the 512x512 ONNX export.

        The crop is resized to 512x512, passed as (image, mask) inputs, resized
        back, and only masked pixels are written into a copy of the image.
        """
        SZ = 512
        h, w = img_np.shape[:2]
        ys, xs = np.where(mask > 0)
        if len(xs) == 0:
            return Image.fromarray(img_np)

        x1, x2 = int(xs.min()), int(xs.max())
        y1, y2 = int(ys.min()), int(ys.max())
        px = max(40, int((x2 - x1) * .25))
        py = max(40, int((y2 - y1) * .25))
        x1, y1 = max(0, x1 - px), max(0, y1 - py)
        x2, y2 = min(w, x2 + 1 + px), min(h, y2 + 1 + py)

        crop = img_np[y1:y2, x1:x2]
        crop_mask = mask[y1:y2, x1:x2] > 0
        ch, cw = crop.shape[:2]

        img_r = cv2.resize(crop, (SZ, SZ), interpolation=cv2.INTER_LINEAR)
        # cv2.resize rejects bool arrays, so resize a uint8 copy of the mask.
        mask_r = cv2.resize(crop_mask.astype(np.uint8), (SZ, SZ), interpolation=cv2.INTER_NEAREST)
        img_t = (img_r.astype(np.float32) / 255.).transpose(2, 0, 1)[None]
        mask_t = (mask_r > 0).astype(np.float32)[None, None]

        inputs = self._session.get_inputs()
        out_name = self._session.get_outputs()[0].name
        res = self._session.run([out_name], {inputs[0].name: img_t, inputs[1].name: mask_t})[0][0]
        # Output range is detected heuristically: [0, 1] gets rescaled, [0, 255] does not.
        scale = 255. if res.max() <= 1.5 else 1.
        res = np.ascontiguousarray((res.transpose(1, 2, 0) * scale).clip(0, 255).astype(np.uint8))
        res = cv2.resize(res, (cw, ch), interpolation=cv2.INTER_LINEAR)

        out = img_np.copy()
        out[y1:y2, x1:x2][crop_mask] = res[crop_mask]
        return Image.fromarray(out)

    @staticmethod
    def _opencv_inpaint(img_np, mask):
        """Classical Navier-Stokes inpainting (radius 5) via OpenCV; not used by ``inpaint``."""
        print(" (OpenCV NS fallback)")
        bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        res = cv2.inpaint(bgr, (mask > 0).astype(np.uint8) * 255, 5, cv2.INPAINT_NS)
        return Image.fromarray(cv2.cvtColor(res, cv2.COLOR_BGR2RGB))