Vbai-2.6MSS

Description

Vbai-2.6 is a 3D brain MRI segmentation model and the latest-generation member of the Vbai model family. Unlike previous versions, Vbai-2.6MSS works exclusively with NIfTI files and is intended for professional research use. The Vbai-3D versions have been merged with the standard Vbai versions.

The model generates voxel-level segmentation masks instead of image-level labels and provides spatial localization of pathological regions in addition to quantitative tissue volume measurements.
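As an illustration of how a voxel-level mask turns into a volume measurement, here is a minimal sketch. The function name `lesion_volume_ml` and the example voxel spacing are assumptions for illustration, not part of the Vbai code. The voxel size must be the spacing of the grid the mask is defined on, so a mask produced on a resampled canvas needs the resampled spacing, not the spacing of the original scan.

```python
import numpy as np

def lesion_volume_ml(mask, voxel_size_mm):
    """Return the volume of all foreground voxels in millilitres.

    mask          : binary array (non-zero = lesion)
    voxel_size_mm : (dx, dy, dz) spacing of the grid the mask lives on
    """
    voxel_mm3 = float(np.prod(voxel_size_mm))    # volume of one voxel in mm^3
    n_voxels = int(np.count_nonzero(mask))       # number of lesion voxels
    return n_voxels * voxel_mm3 / 1000.0         # 1 mL = 1000 mm^3

# Toy mask: a 2x2x2 block of lesion voxels (8 voxels)
mask = np.zeros((10, 10, 10), dtype=np.uint8)
mask[2:4, 2:4, 2:4] = 1

# Hypothetical spacing of 1.0 x 1.0 x 1.5 mm
volume = lesion_volume_ml(mask, (1.0, 1.0, 1.5))
print(f"{volume:.3f} mL")
```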

Vbai-2.6MSS also serves as the core engine of the HealFuture image processing library and can run each diagnostic task independently or in combination, depending on the clinical use case. This model is trained exclusively for MS lesion segmentation.

Audience / Target

Vbai models are developed exclusively for hospitals, universities, communities, health centres and science centres.

Architecture

Input: FLAIR (1ch)
Shared Encoder: 3D ResNet + CBAM + SE + ASPP
Output: Lesion Decoder (Attention-Gated) → Binary MS lesion mask

  • Encoder: Custom 3D ResNet (channels 32→64→128→256, bottleneck 320) with CBAM, Squeeze-and-Excitation, and ASPP modules
  • Decoder heads: Single UNet-style decoder with attention gates
  • Deep supervision: 3 auxiliary outputs during training
  • Patch-based: 96³ patches sampled from a 128³ canvas; lesion-component-balanced foreground sampling
  • Inference: Sliding window (96³ patches, 50 % overlap, Gaussian blending) + optional TTA
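
The sliding-window layout can be sketched as follows. This mirrors the `starts` helper in the script further down; the name `window_starts` is only used here for illustration. With a 96-voxel patch and 50 % overlap the stride is 48, but a 128-voxel axis only leaves room for a second window at offset 32, so the last window is shifted back and overlaps its neighbour by more than 50 %.

```python
def window_starts(dim, patch, overlap):
    """Start offsets of sliding windows along one axis."""
    stride = max(1, int(patch * (1 - overlap)))
    if dim <= patch:
        return [0]                      # a single window covers the axis
    starts = list(range(0, dim - patch + 1, stride))
    if starts[-1] != dim - patch:
        starts.append(dim - patch)      # shift the last window to touch the edge
    return starts

starts = window_starts(128, 96, 0.5)
n_patches = len(starts) ** 3            # same layout on all three axes
print(starts, n_patches)                # [0, 32] 8
```

Under these settings each 128³ volume needs 8 forward passes, which is why the overlapping regions are blended with a Gaussian window rather than simply averaged.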

General Tests

  • Input/Patch Size: 96³ patches on a 128³ grid (FLAIR)
  • Params: 30.73M
  • Accuracy: 100%
  • ROC-AUC: 94.40%
  • F1 Score: 63.55%
  • F1 Score (Median): 65.13%
  • F1 Score (Lesion): 79.54%
  • Recall: 74.57%
  • Precision: 57.65%
  • IoU: 47.55%
  • LTPR: 75.16%
  • F2 Score: 69.13%
  • HD95 (mean): 16.7309 mm (128³ grid)
  • HD95 (median): 12.5607 mm (128³ grid)
  • MCC: 64.78%
  • Specificity: 100%
  • FPR: < 0.01%
  • FNR: 25.43%
  • Volumetric Similarity: 83.65%

*Tested on the MSLesSeg dataset; the MSLesSeg dataset was excluded from training.

*The model was trained for 74 epochs.

*No transfer learning or pre-trained weights were used.
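
For readers who want to reproduce voxel-wise scores on their own data, the metrics above follow the usual confusion-matrix definitions. The sketch below is an assumption about how such scores can be computed from two binary masks; `voxel_metrics` is a hypothetical helper, not the evaluation code behind the reported numbers, and it says nothing about how per-case scores were aggregated.

```python
import numpy as np

def voxel_metrics(pred, truth):
    """Voxel-wise overlap metrics for two binary masks of equal shape."""
    pred = np.asarray(pred).astype(bool)
    truth = np.asarray(truth).astype(bool)
    tp = int(np.sum(pred & truth))      # lesion predicted, lesion present
    fp = int(np.sum(pred & ~truth))     # lesion predicted, none present
    fn = int(np.sum(~pred & truth))     # lesion missed
    tn = int(np.sum(~pred & ~truth))    # background correctly left empty
    eps = 1e-8                          # avoids division by zero on empty masks
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "f1": 2 * tp / (2 * tp + fp + fn + eps),   # identical to Dice
        "iou": tp / (tp + fp + fn + eps),
        "precision": precision,
        "recall": recall,
        "f2": 5 * precision * recall / (4 * precision + recall + eps),
        "specificity": tn / (tn + fp + eps),
        "fnr": fn / (fn + tp + eps),               # equals 1 - recall
    }

# Toy 2x2x2 example: 2 true positives, 1 false positive, 1 false negative
pred = np.array([1, 1, 1, 0, 0, 0, 0, 0]).reshape(2, 2, 2)
truth = np.array([1, 1, 0, 1, 0, 0, 0, 0]).reshape(2, 2, 2)
m = voxel_metrics(pred, truth)
print({k: round(v, 4) for k, v in m.items()})
```

Because lesions occupy a tiny fraction of the brain, accuracy and specificity are dominated by background voxels; the overlap metrics (F1, IoU) and the lesion-wise scores are the more informative ones.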

Usage

Python Script (PT Models)
"""
==============================================================================
Vbai-2.6MSS  -  Multiple Sclerosis Lesion Segmentation (standalone)
==============================================================================
Single-file, shareable inference + visualization script.

Give a path to a brain FLAIR file -> the model marks the MS lesions ->
the marked slices are displayed (red overlay) and saved as a PNG.

Architecture (embedded below):
  - Input : FLAIR (1 channel)
  - Encoder: custom 3D ResNet (32->64->128->256, bottleneck 320)
             with Squeeze-and-Excitation + CBAM + ASPP
  - Decoder: single UNet-style decoder with attention gates
  - Deep supervision: 3 auxiliary heads (training only)
  - Inference: sliding window (96^3 patches on a 128^3 canvas, 50% overlap,
               Gaussian blending)
  - Single task, FLAIR-only, ~30.7M params

Usage:
  python Vbai-2.6MSS.py --flair "/path/to/patient_FLAIR.nii.gz"
  optional: --ckpt <weights.pt>  --out <result.png>  --grid 128  --n-slices 6

Requirements: torch, numpy, nibabel, scipy, matplotlib
==============================================================================
"""
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from scipy.ndimage import zoom, label as cc_label

# Detect Colab for inline display
try:
  import google.colab  # noqa: F401
  _IN_COLAB = True
except Exception:
  _IN_COLAB = False
import matplotlib
if not _IN_COLAB:
  matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ---- Inference settings (this is the 96^3-patch / 128^3-grid variant) -------
GRID       = (128, 128, 128)   # canvas the brain is resampled to
PATCH      = (96, 96, 96)      # model input patch (must be divisible by 16)
OVERLAP    = 0.5               # sliding-window overlap ratio
THRESHOLD  = 0.5               # probability threshold for the binary mask
MIN_CC     = 10                # drop predicted components smaller than this (noise)


# ============================================================================
#  MODEL ARCHITECTURE  (Vbai-2.6MSS)
#  NOTE: submodule attribute names must stay identical to the trained model
#        so the checkpoint state_dict loads correctly.
# ============================================================================
class SEBlock3D(nn.Module):
  """Squeeze-and-Excitation: channel-wise recalibration."""
  def __init__(self, ch, r=16):
      super().__init__()
      mid = max(ch // r, 4)
      self.pool = nn.AdaptiveAvgPool3d(1)
      self.fc = nn.Sequential(nn.Linear(ch, mid), nn.ReLU(True),
                              nn.Linear(mid, ch), nn.Sigmoid())

  def forward(self, x):
      b, c = x.shape[:2]
      return x * self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1, 1)


class CBAM3D(nn.Module):
  """Convolutional Block Attention Module (channel + spatial)."""
  def __init__(self, ch, r=16, ks=7):
      super().__init__()
      mid = max(ch // r, 4)
      self.avg = nn.AdaptiveAvgPool3d(1)
      self.mx = nn.AdaptiveMaxPool3d(1)
      self.ch_fc = nn.Sequential(nn.Linear(ch, mid), nn.ReLU(True), nn.Linear(mid, ch))
      self.sp = nn.Sequential(nn.Conv3d(2, 1, ks, padding=ks // 2, bias=False), nn.BatchNorm3d(1))

  def forward(self, x):
      b, c = x.shape[:2]
      ch = torch.sigmoid(self.ch_fc(self.avg(x).view(b, c)) +
                         self.ch_fc(self.mx(x).view(b, c))).view(b, c, 1, 1, 1)
      x = x * ch
      sp = torch.sigmoid(self.sp(torch.cat([x.mean(1, True), x.max(1, True).values], 1)))
      return x * sp


class ResBlock3D(nn.Module):
  """Residual block with SE + CBAM attention."""
  def __init__(self, ic, oc, stride=1, drop=0.1, se=True, cbam=True):
      super().__init__()
      self.conv = nn.Sequential(
          nn.Conv3d(ic, oc, 3, stride, 1, bias=False), nn.BatchNorm3d(oc), nn.ReLU(True),
          nn.Dropout3d(drop),
          nn.Conv3d(oc, oc, 3, 1, 1, bias=False), nn.BatchNorm3d(oc))
      self.skip = (nn.Sequential(nn.Conv3d(ic, oc, 1, stride, bias=False), nn.BatchNorm3d(oc))
                   if ic != oc or stride != 1 else nn.Identity())
      self.se = SEBlock3D(oc) if se else nn.Identity()
      self.cbam = CBAM3D(oc) if cbam else nn.Identity()
      self.act = nn.ReLU(True)

  def forward(self, x):
      return self.act(self.cbam(self.se(self.conv(x))) + self.skip(x))


class ASPP3D(nn.Module):
  """Atrous Spatial Pyramid Pooling: multi-scale context."""
  def __init__(self, ic, oc, dils=(1, 3, 6)):
      super().__init__()
      mid = oc // (len(dils) + 2)
      self.branches = nn.ModuleList([
          nn.Sequential(nn.Conv3d(ic, mid, 3, padding=d, dilation=d, bias=False),
                        nn.BatchNorm3d(mid), nn.ReLU(True)) for d in dils])
      self.gp = nn.Sequential(nn.AdaptiveAvgPool3d(1),
                              nn.Conv3d(ic, mid, 1, bias=False), nn.ReLU(True))
      self.pw = nn.Sequential(nn.Conv3d(ic, mid, 1, bias=False),
                              nn.BatchNorm3d(mid), nn.ReLU(True))
      tot = mid * (len(dils) + 2)
      self.proj = nn.Sequential(nn.Conv3d(tot, oc, 1, bias=False),
                                nn.BatchNorm3d(oc), nn.ReLU(True), nn.Dropout3d(0.1))

  def forward(self, x):
      sz = x.shape[2:]
      fs = [b(x) for b in self.branches]
      fs.append(F.interpolate(self.gp(x), sz, mode="trilinear", align_corners=False))
      fs.append(self.pw(x))
      return self.proj(torch.cat(fs, 1))


class AttGate3D(nn.Module):
  """Attention gate: filters skip features using the decoder gating signal."""
  def __init__(self, fc, gc):
      super().__init__()
      ic = fc // 2
      self.Wf = nn.Sequential(nn.Conv3d(fc, ic, 1, bias=False), nn.BatchNorm3d(ic))
      self.Wg = nn.Sequential(nn.Conv3d(gc, ic, 1, bias=False), nn.BatchNorm3d(ic))
      self.ps = nn.Sequential(nn.Conv3d(ic, 1, 1, bias=False), nn.BatchNorm3d(1), nn.Sigmoid())
      self.r = nn.ReLU(True)

  def forward(self, feat, gate):
      if gate.shape[2:] != feat.shape[2:]:
          gate = F.interpolate(gate, feat.shape[2:], mode="trilinear", align_corners=False)
      return feat * self.ps(self.r(self.Wf(feat) + self.Wg(gate)))


class EncBlock(nn.Module):
  """Two residual blocks + strided downsample. Returns (skip, downsampled)."""
  def __init__(self, ic, oc, drop=0.1):
      super().__init__()
      self.blk = nn.Sequential(ResBlock3D(ic, oc, drop=drop), ResBlock3D(oc, oc, drop=drop))
      self.down = nn.Sequential(nn.Conv3d(oc, oc, 3, stride=2, padding=1, bias=False),
                                nn.BatchNorm3d(oc), nn.ReLU(True))

  def forward(self, x):
      s = self.blk(x)
      return s, self.down(s)


class DecBlock(nn.Module):
  """Upsample + attention-gated skip fusion + two residual blocks."""
  def __init__(self, ic, sc, oc, drop=0.1, ag=True):
      super().__init__()
      self.ag = AttGate3D(sc, ic) if ag else None
      self.blk = nn.Sequential(ResBlock3D(ic + sc, oc, drop=drop), ResBlock3D(oc, oc, drop=drop))

  def forward(self, x, skip):
      x = F.interpolate(x, skip.shape[2:], mode="trilinear", align_corners=False)
      if self.ag:
          skip = self.ag(skip, x)
      return self.blk(torch.cat([x, skip], 1))


class Vbai26MSS(nn.Module):
  """
  Vbai-2.6MSS - single-task 3D UNet for MS lesion segmentation.
  Input: FLAIR (in_ch=1) -> output: 1-channel lesion logit.
  """
  def __init__(self, in_ch=1, bc=32, mults=(1, 2, 4, 8, 10), drop=0.1, ds=True):
      super().__init__()
      ch = [bc * m for m in mults]
      self.ds = ds
      self.stem = nn.Sequential(nn.Conv3d(in_ch, ch[0], 3, 1, 1, bias=False),
                                nn.BatchNorm3d(ch[0]), nn.ReLU(True))
      self.e0 = EncBlock(ch[0], ch[0], drop=drop)
      self.e1 = EncBlock(ch[0], ch[1], drop=drop)
      self.e2 = EncBlock(ch[1], ch[2], drop=drop)
      self.e3 = EncBlock(ch[2], ch[3], drop=drop)
      self.bn = nn.Sequential(ResBlock3D(ch[3], ch[4], drop=drop), ASPP3D(ch[4], ch[4]))
      self.d0 = DecBlock(ch[4], ch[3], ch[3], drop=drop)
      self.d1 = DecBlock(ch[3], ch[2], ch[2], drop=drop)
      self.d2 = DecBlock(ch[2], ch[1], ch[1], drop=drop)
      self.d3 = DecBlock(ch[1], ch[0], ch[0], drop=drop)
      self.head = nn.Conv3d(ch[0], 1, 1)
      if ds:
          self.ds0 = nn.Conv3d(ch[3], 1, 1)
          self.ds1 = nn.Conv3d(ch[2], 1, 1)
          self.ds2 = nn.Conv3d(ch[1], 1, 1)

  def forward(self, x, return_aux=False):
      s = self.stem(x)
      k0, d0 = self.e0(s)
      k1, d1 = self.e1(d0)
      k2, d2 = self.e2(d1)
      k3, d3 = self.e3(d2)
      bn = self.bn(d3)
      u3 = self.d0(bn, k3)
      u2 = self.d1(u3, k2)
      u1 = self.d2(u2, k1)
      u0 = self.d3(u1, k0)
      return self.head(u0)   # inference only: no aux heads needed


# ============================================================================
#  PREPROCESS + SLIDING-WINDOW INFERENCE
# ============================================================================
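def load_nii(path):
  """Load a NIfTI volume as float32 (handles a .nii path that is actually a folder)."""
  if os.path.isdir(path):
      # The path is a folder: use the first NIfTI file found inside it
      inner = sorted(f for f in os.listdir(path) if f.lower().endswith((".nii", ".nii.gz")))
      if not inner:
          raise FileNotFoundError(f"No .nii / .nii.gz file found inside folder: {path}")
      path = os.path.join(path, inner[0])
  v = np.asarray(nib.load(path).dataobj, dtype=np.float32)
  # Replace NaN / Inf so the normalization statistics stay finite
  return np.nan_to_num(v, nan=0., posinf=0., neginf=0.)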


def znorm(v):
  """Z-score normalization over the foreground (v > 0)."""
  m = v > 0
  if not m.any():
      return v.astype(np.float32)
  out = np.zeros_like(v, dtype=np.float32)
  out[m] = (v[m] - v[m].mean()) / max(v[m].std(), 1e-6)
  return out


def resamp(v, target, order=1):
  return zoom(v, [t / c for t, c in zip(target, v.shape)], order=order).astype(np.float32)


@torch.no_grad()
def sliding_window_prob(model, x, patch=PATCH, overlap=OVERLAP):
  """Run the model over overlapping patches and blend with a Gaussian window."""
  _, _, D, H, W = x.shape
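  # Clamp the patch to the volume so a canvas smaller than the patch (e.g. --grid 64)
  # does not cause a shape mismatch between the prediction and the blending window
  pd, ph, pw = (min(p, d) for p, d in zip(patch, (D, H, W)))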
  sd = max(1, int(pd * (1 - overlap)))
  sh = max(1, int(ph * (1 - overlap)))
  sw = max(1, int(pw * (1 - overlap)))

  def starts(dim, p, s):
      if dim <= p:
          return [0]
      st = list(range(0, dim - p + 1, s))
      if st[-1] != dim - p:
          st.append(dim - p)
      return st

  def gauss1d(n):
      c = (n - 1) / 2.0
      s = n * 0.125 + 1e-6
      return np.exp(-0.5 * ((np.arange(n) - c) / s) ** 2)

  win = torch.tensor((gauss1d(pd)[:, None, None] * gauss1d(ph)[None, :, None] *
                      gauss1d(pw)[None, None, :]).astype(np.float32),
                     device=x.device).clamp_min(1e-4)

  acc = torch.zeros((D, H, W), device=x.device)
  cnt = torch.zeros((D, H, W), device=x.device)
  for z0 in starts(D, pd, sd):
      for y0 in starts(H, ph, sh):
          for x0 in starts(W, pw, sw):
              patch_in = x[:, :, z0:z0 + pd, y0:y0 + ph, x0:x0 + pw]
              p = torch.sigmoid(model(patch_in))[0, 0]
              acc[z0:z0 + pd, y0:y0 + ph, x0:x0 + pw] += p * win
              cnt[z0:z0 + pd, y0:y0 + ph, x0:x0 + pw] += win
  return (acc / cnt.clamp_min(1e-6)).cpu().numpy()


def clean_small(mask, min_size):
  """Remove connected components smaller than min_size voxels."""
  if min_size <= 1:
      return mask
  cc, n = cc_label(mask)
  if n == 0:
      return mask
  sizes = np.bincount(cc.ravel())
  sizes[0] = 0
  return np.isin(cc, np.where(sizes >= min_size)[0]).astype(np.uint8)


# ============================================================================
#  MAIN: mark a brain file and visualize
# ============================================================================
@torch.no_grad()
def mark_brain(flair_path, ckpt_path, out_png, grid, n_slices, device):
  # Build model and load weights (load only matching keys: aux/ds heads optional)
  model = Vbai26MSS(in_ch=1, ds=True).to(device)
  ck = torch.load(ckpt_path, map_location=device, weights_only=False)
  state = ck["model"] if isinstance(ck, dict) and "model" in ck else ck
  model.load_state_dict(state, strict=False)
  model.eval()
  print(f"Vbai-2.6MSS loaded: {os.path.basename(ckpt_path)}")

  # Preprocess: foreground z-score + resample to canvas grid
  raw = load_nii(flair_path)
  flair = resamp(znorm(raw), grid)
  disp = resamp(raw, grid)
  if (disp > 0).any():
      lo, hi = np.percentile(disp[disp > 0], [1, 99])
      disp = np.clip((disp - lo) / (hi - lo + 1e-6), 0, 1)

  # Inference -> probability -> binary mask -> noise cleanup
  x = torch.tensor(flair[None, None], dtype=torch.float32).to(device)
  prob = sliding_window_prob(model, x, PATCH, OVERLAP)
  pred = clean_small((prob >= THRESHOLD).astype(np.uint8), MIN_CC)

  n_vox = int(pred.sum())
  n_les = int(cc_label(pred)[1])
  print(f"Marked lesions: {n_les} (total {n_vox} voxels)")

  # Pick the most-lesion axial slices; overlay prediction in red
  scores = pred.sum(axis=(0, 1))
  zs = sorted(np.argsort(scores)[-n_slices:]) if scores.sum() > 0 else [grid[2] // 2]
  cols = len(zs)
  fig, axes = plt.subplots(1, cols, figsize=(3 * cols, 3.4))
  if cols == 1:
      axes = [axes]
  for ax, z in zip(axes, zs):
      ax.imshow(disp[:, :, z].T, cmap="gray", origin="lower")
      overlay = np.zeros((*pred[:, :, z].T.shape, 4), np.float32)
      overlay[..., 0] = 1.0
      overlay[..., 3] = (pred[:, :, z].T > 0) * 0.5
      ax.imshow(overlay, origin="lower")
      ax.set_title(f"z={z}", fontsize=8)
      ax.axis("off")
  fig.suptitle(f"Vbai-2.6MSS  |  {os.path.basename(flair_path)}  |  "
               f"{n_les} lesion(s), {n_vox} voxels  (red = prediction)", fontsize=11)
  plt.tight_layout(rect=[0, 0, 1, 0.94])
  plt.savefig(out_png, dpi=130, bbox_inches="tight")
  print(f"Saved -> {out_png}")
  if _IN_COLAB:
      plt.show()
  plt.close(fig)


def main():
  ap = argparse.ArgumentParser(description="Vbai-2.6MSS - mark MS lesions on a FLAIR brain scan")
  ap.add_argument("--flair", required=True, help="path to a FLAIR .nii / .nii.gz file")
  ap.add_argument("--ckpt", default="Vbai-2.6MSS.pt", help="path to model weights")
  ap.add_argument("--out", default=None, help="output PNG path (default: next to input)")
  ap.add_argument("--grid", type=int, default=GRID[0], help="canvas size (default 128)")
  ap.add_argument("--n-slices", type=int, default=6, help="number of slices to display")
  args = ap.parse_args()

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  if not os.path.exists(args.ckpt):
      print(f"Checkpoint not found: {args.ckpt}")
      return
  if not os.path.exists(args.flair):
      print(f"FLAIR not found: {args.flair}")
      return
  out = args.out or (os.path.splitext(args.flair.replace(".nii.gz", ".nii"))[0] + "_Vbai-2.6MSS.png")
  grid = (args.grid, args.grid, args.grid)
  mark_brain(args.flair, args.ckpt, out, grid, args.n_slices, device)


if __name__ == "__main__":
  main()
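
The Architecture section lists optional TTA, which the script above does not apply. One common form is flip averaging: predict on mirrored copies of the volume, mirror each prediction back, and average. The sketch below is an assumption about how that could look, not the TTA used by the authors; `flip_tta` and `dummy_predict` are hypothetical names, and `dummy_predict` merely stands in for a call such as `sliding_window_prob`.

```python
import numpy as np

def flip_tta(predict, volume, axes=(0, 1, 2)):
    """Average predictions over the original volume and single-axis flips."""
    probs = [predict(volume)]
    for ax in axes:
        flipped = np.flip(volume, axis=ax)
        # Flip the prediction back so it aligns with the original volume
        probs.append(np.flip(predict(flipped), axis=ax))
    return np.mean(probs, axis=0)

def dummy_predict(v):
    """Stand-in for the model: an elementwise sigmoid (hypothetical)."""
    return 1.0 / (1.0 + np.exp(-v))

vol = np.random.default_rng(0).normal(size=(8, 8, 8)).astype(np.float32)
avg = flip_tta(dummy_predict, vol)
print(avg.shape)
```

Each extra flip adds one full sliding-window pass, so three flips make inference roughly four times slower.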

Requirements

  • Python ≥ 3.9
  • PyTorch ≥ 2.0
  • CUDA-capable GPU, ≥ 16 GB VRAM recommended (tested on an NVIDIA Tesla T4 with 16 GB of VRAM; trained on an NVIDIA L4 with 24 GB of VRAM)
  • See requirements.txt for full dependency list

License

CC-BY-NC-SA 4.0 - see LICENSE file for details.

Support

