Vbai-2.6MSS
Description
Vbai-2.6MSS is a 3D brain MRI segmentation model from Vbai-2.6, the latest generation of the Vbai model family. Unlike previous versions, it works exclusively with NIfTI files and is intended for professional research use. The former Vbai-3D versions have been merged into the standard Vbai versions.
The model produces voxel-level segmentation masks rather than image-level labels, providing spatial localization of pathological regions as well as quantitative lesion volume measurements.
Vbai-2.6MSS also serves as the core engine of the HealFuture image processing library, where Vbai models can run each diagnostic task independently or in combination, depending on the clinical use case. This model itself is trained exclusively for MS lesion segmentation.
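The volume measurements mentioned above can be derived from the binary mask and the scan's voxel spacing. A minimal sketch, assuming the mask lives on the resampled canvas used by the inference script (128³ by default) and that the native shape and voxel spacing are known; `lesion_volume_ml` is a hypothetical helper, not part of the Vbai or HealFuture API. Because the script resamples every scan to a fixed grid, the native spacing has to be rescaled to the canvas before counting voxels.

```python
import numpy as np


def lesion_volume_ml(mask, native_shape, native_spacing_mm):
    """Approximate lesion volume in millilitres from a mask on a resampled canvas.

    Assumes (hypothetically) that the canvas was produced by resampling the
    native volume (shape `native_shape`, voxel size `native_spacing_mm`) to
    `mask.shape`, as the inference script does with its 128^3 grid.
    """
    mask = np.asarray(mask)
    # Each canvas voxel spans roughly native_extent / canvas_size mm per axis.
    canvas_spacing = [sp * n / g for sp, n, g in
                      zip(native_spacing_mm, native_shape, mask.shape)]
    voxel_mm3 = float(np.prod(canvas_spacing))
    return np.count_nonzero(mask) * voxel_mm3 / 1000.0  # 1 mL = 1000 mm^3


# Example: a 256^3 scan at 1 mm isotropic, predicted on a 128^3 canvas,
# so every canvas voxel covers 2 x 2 x 2 = 8 mm^3.
mask = np.zeros((128, 128, 128), dtype=np.uint8)
mask[:5, :5, :5] = 1  # 125 lesion voxels
print(lesion_volume_ml(mask, (256, 256, 256), (1.0, 1.0, 1.0)))  # 1.0
```

With nibabel, the native values can be read from `img = nib.load(path)` as `img.shape[:3]` and `img.header.get_zooms()[:3]`.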
Audience / Target
Vbai models are developed exclusively for hospitals, universities, communities, health centres and science centres.
Architecture
| Input | Encoder | Decoder → Output |
|---|---|---|
| FLAIR (1 channel) | 3D ResNet + CBAM + SE + ASPP | Attention-gated lesion decoder → binary MS lesion mask |
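The encoder halves the spatial resolution once per stage, which is why the script requires patch edges divisible by 16. A small arithmetic sketch, assuming one stride-2 downsample after each of the four encoder stages as in the embedded `EncBlock`, with the channel widths listed below; `trace_shapes` is a hypothetical helper for illustration only.

```python
def trace_shapes(patch=96, channels=(32, 64, 128, 256), bottleneck=320):
    """List (stage, channels, edge length) for a cubic patch through the encoder.

    Assumes one stride-2 downsample after each encoder stage, mirroring
    EncBlock in the script; residual blocks keep the resolution unchanged.
    """
    levels = len(channels)
    if patch % 2 ** levels:
        raise ValueError(f"patch edge {patch} is not divisible by {2 ** levels}")
    rows, size = [], patch
    for i, ch in enumerate(channels):
        rows.append((f"skip{i}", ch, size))  # skip connection handed to the decoder
        size //= 2                           # stride-2 conv halves each axis
    rows.append(("bottleneck", bottleneck, size))
    return rows


for stage, ch, edge in trace_shapes():
    print(f"{stage:>10}: {ch:3d} ch @ {edge}^3")
```

For the default 96³ patch this gives skip features at 96³, 48³, 24³ and 12³, and a 320-channel bottleneck at 6³.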
- Encoder: Custom 3D ResNet (channels 32→64→128→256, bottleneck 320) with CBAM, Squeeze-and-Excitation, and ASPP modules
- Decoder: single UNet-style decoder with attention gates
- Deep supervision: 3 auxiliary outputs during training
- Training sampling: 96³ patches sampled from a 128³ canvas, with lesion-component-balanced foreground sampling
- Inference: sliding window (96³ patches, 50% overlap, Gaussian blending) + optional TTA
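The card does not spell out how lesion-component-balanced sampling works. One plausible reading, sketched here as an assumption rather than the authors' training code: label the connected lesion components of the ground-truth mask, pick a component uniformly at random so that small lesions are drawn as often as large ones, and place the 96³ patch around one of its voxels. `sample_patch_origin` is a hypothetical name.

```python
import numpy as np
from scipy.ndimage import label


def sample_patch_origin(gt_mask, patch=(96, 96, 96), rng=None):
    """Pick a patch corner so the patch contains a voxel of a random lesion.

    Every connected lesion component has the same chance of being chosen,
    regardless of its size (assumed reading of "component-balanced").
    """
    rng = np.random.default_rng() if rng is None else rng
    shape = np.array(gt_mask.shape)
    size = np.array(patch)
    comps, n = label(gt_mask > 0)        # 6-connected components, labels 1..n
    if n == 0:
        centre = rng.integers(0, shape)  # no lesion: uniform random centre
    else:
        k = rng.integers(1, n + 1)       # uniform over components, not voxels
        voxels = np.argwhere(comps == k)
        centre = voxels[rng.integers(len(voxels))]
    # Clamp the corner so the patch stays inside the canvas; the chosen
    # centre voxel still lies inside the patch after clamping.
    origin = np.clip(centre - size // 2, 0, np.maximum(shape - size, 0))
    return tuple(int(o) for o in origin)
```

Sampling per component rather than per voxel keeps a single large lesion from dominating the foreground patches.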
General Tests
| Input / Patch Size | Params | Accuracy | ROC-AUC | F1 Score | F1 Score (median) | F1 Score (lesion-wise) | Recall | Precision | IoU | LTPR | F2 Score | HD95 mean (mm, 128³ grid) | HD95 median (mm, 128³ grid) | MCC | Specificity | FPR | FNR | Volumetric Similarity |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 96³ patches / 128³ grid (FLAIR) | 30.73M | 100% | 94.40% | 63.55% | 65.13% | 79.54% | 74.57% | 57.65% | 47.55% | 75.16% | 69.13% | 16.7309 | 12.5607 | 64.78% | 100% | <0.01% | 25.43% | 83.65% |
*Tested on the MSLesSeg dataset, which was excluded from training.
*The model was trained for 74 epochs.
*No transfer learning or pre-trained weights were used.
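The voxel-wise columns in the table follow the standard confusion-matrix definitions. A minimal sketch computing them for a single case, assuming binary prediction and ground-truth arrays of equal shape; `voxel_metrics` is a hypothetical helper, and how the reported numbers were aggregated across cases (the separate mean and median F1 columns suggest per-case scores) is not stated. Lesion-wise F1, LTPR, ROC-AUC and HD95 need connected-component, ranking or distance computations and are left out. Because lesion voxels are a tiny fraction of a brain volume, true negatives dominate, which is why Accuracy and Specificity round to 100% even with an FNR of about 25%.

```python
import numpy as np


def voxel_metrics(pred, gt):
    """Voxel-wise binary segmentation metrics from the confusion counts."""
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    tp = float(np.sum(pred & gt))
    fp = float(np.sum(pred & ~gt))
    fn = float(np.sum(~pred & gt))
    tn = float(np.sum(~pred & ~gt))
    eps = 1e-12  # avoids division by zero on empty masks
    mcc_den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) + eps
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "f1": 2 * tp / (2 * tp + fp + fn + eps),        # identical to the Dice coefficient
        "f2": 5 * tp / (5 * tp + 4 * fn + fp + eps),    # F-beta with beta = 2 (recall-weighted)
        "iou": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
        "specificity": tn / (tn + fp + eps),
        "fpr": fp / (fp + tn + eps),
        "fnr": fn / (fn + tp + eps),
        "mcc": (tp * tn - fp * fn) / mcc_den,
        "volumetric_similarity": 1 - abs(fn - fp) / (2 * tp + fp + fn + eps),
    }
```

Note that FPR = 1 - Specificity and FNR = 1 - Recall, which the table's values are consistent with (25.43% = 100% - 74.57%).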
Usage
Python Script (PT Models)
"""
==============================================================================
Vbai-2.6MSS - Multiple Sclerosis Lesion Segmentation (standalone)
==============================================================================
Single-file, shareable inference + visualization script.
Give a path to a brain FLAIR file -> the model marks the MS lesions ->
the marked slices are displayed (red overlay) and saved as a PNG.
Architecture (embedded below):
- Input : FLAIR (1 channel)
- Encoder: custom 3D ResNet (32->64->128->256, bottleneck 320)
with Squeeze-and-Excitation + CBAM + ASPP
- Decoder: single UNet-style decoder with attention gates
- Deep supervision: 3 auxiliary heads (training only)
- Inference: sliding window (96^3 patches on a 128^3 canvas, 50% overlap,
Gaussian blending)
- Single task, FLAIR-only, ~30.7M params
Usage:
python Vbai-2.6MSS.py --flair "/path/to/patient_FLAIR.nii.gz"
optional: --ckpt <weights.pt> --out <result.png> --grid 128 --n-slices 6
Requirements: torch, numpy, nibabel, scipy, matplotlib
==============================================================================
"""
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from scipy.ndimage import zoom, label as cc_label
# Detect Colab for inline display
try:
    import google.colab  # noqa: F401
    _IN_COLAB = True
except Exception:
    _IN_COLAB = False
import matplotlib
if not _IN_COLAB:
    matplotlib.use("Agg")
import matplotlib.pyplot as plt
# ---- Inference settings (this is the 96^3-patch / 128^3-grid variant) -------
GRID = (128, 128, 128) # canvas the brain is resampled to
PATCH = (96, 96, 96) # model input patch (must be divisible by 16)
OVERLAP = 0.5 # sliding-window overlap ratio
THRESHOLD = 0.5 # probability threshold for the binary mask
MIN_CC = 10 # drop predicted components smaller than this (noise)
# ============================================================================
# MODEL ARCHITECTURE (Vbai-2.6MSS)
# NOTE: submodule attribute names must stay identical to the trained model
# so the checkpoint state_dict loads correctly.
# ============================================================================
class SEBlock3D(nn.Module):
    """Squeeze-and-Excitation: channel-wise recalibration."""
    def __init__(self, ch, r=16):
        super().__init__()
        mid = max(ch // r, 4)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(nn.Linear(ch, mid), nn.ReLU(True),
                                nn.Linear(mid, ch), nn.Sigmoid())
    def forward(self, x):
        b, c = x.shape[:2]
        return x * self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1, 1)
class CBAM3D(nn.Module):
    """Convolutional Block Attention Module (channel + spatial)."""
    def __init__(self, ch, r=16, ks=7):
        super().__init__()
        mid = max(ch // r, 4)
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.mx = nn.AdaptiveMaxPool3d(1)
        self.ch_fc = nn.Sequential(nn.Linear(ch, mid), nn.ReLU(True), nn.Linear(mid, ch))
        self.sp = nn.Sequential(nn.Conv3d(2, 1, ks, padding=ks // 2, bias=False), nn.BatchNorm3d(1))
    def forward(self, x):
        b, c = x.shape[:2]
        # Channel attention: shared MLP over average- and max-pooled descriptors
        ch = torch.sigmoid(self.ch_fc(self.avg(x).view(b, c)) +
                           self.ch_fc(self.mx(x).view(b, c))).view(b, c, 1, 1, 1)
        x = x * ch
        # Spatial attention: conv over the channel-wise mean and max maps
        sp = torch.sigmoid(self.sp(torch.cat([x.mean(1, True), x.max(1, True).values], 1)))
        return x * sp
class ResBlock3D(nn.Module):
    """Residual block with SE + CBAM attention."""
    def __init__(self, ic, oc, stride=1, drop=0.1, se=True, cbam=True):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ic, oc, 3, stride, 1, bias=False), nn.BatchNorm3d(oc), nn.ReLU(True),
            nn.Dropout3d(drop),
            nn.Conv3d(oc, oc, 3, 1, 1, bias=False), nn.BatchNorm3d(oc))
        # 1x1x1 projection when channels or stride change, identity otherwise
        self.skip = (nn.Sequential(nn.Conv3d(ic, oc, 1, stride, bias=False), nn.BatchNorm3d(oc))
                     if ic != oc or stride != 1 else nn.Identity())
        self.se = SEBlock3D(oc) if se else nn.Identity()
        self.cbam = CBAM3D(oc) if cbam else nn.Identity()
        self.act = nn.ReLU(True)
    def forward(self, x):
        # conv -> SE -> CBAM on the residual branch, then add the shortcut
        return self.act(self.cbam(self.se(self.conv(x))) + self.skip(x))
class ASPP3D(nn.Module):
    """Atrous Spatial Pyramid Pooling: multi-scale context."""
    def __init__(self, ic, oc, dils=(1, 3, 6)):
        super().__init__()
        mid = oc // (len(dils) + 2)
        self.branches = nn.ModuleList([
            nn.Sequential(nn.Conv3d(ic, mid, 3, padding=d, dilation=d, bias=False),
                          nn.BatchNorm3d(mid), nn.ReLU(True)) for d in dils])
        self.gp = nn.Sequential(nn.AdaptiveAvgPool3d(1),
                                nn.Conv3d(ic, mid, 1, bias=False), nn.ReLU(True))
        self.pw = nn.Sequential(nn.Conv3d(ic, mid, 1, bias=False),
                                nn.BatchNorm3d(mid), nn.ReLU(True))
        tot = mid * (len(dils) + 2)
        self.proj = nn.Sequential(nn.Conv3d(tot, oc, 1, bias=False),
                                  nn.BatchNorm3d(oc), nn.ReLU(True), nn.Dropout3d(0.1))
    def forward(self, x):
        sz = x.shape[2:]
        fs = [b(x) for b in self.branches]  # dilated 3x3x3 branches
        # Global-context branch, broadcast back to the input size
        fs.append(F.interpolate(self.gp(x), sz, mode="trilinear", align_corners=False))
        fs.append(self.pw(x))  # 1x1x1 pointwise branch
        return self.proj(torch.cat(fs, 1))
class AttGate3D(nn.Module):
    """Attention gate: filters skip features using the decoder gating signal."""
    def __init__(self, fc, gc):
        super().__init__()
        ic = fc // 2
        self.Wf = nn.Sequential(nn.Conv3d(fc, ic, 1, bias=False), nn.BatchNorm3d(ic))
        self.Wg = nn.Sequential(nn.Conv3d(gc, ic, 1, bias=False), nn.BatchNorm3d(ic))
        self.ps = nn.Sequential(nn.Conv3d(ic, 1, 1, bias=False), nn.BatchNorm3d(1), nn.Sigmoid())
        self.r = nn.ReLU(True)
    def forward(self, feat, gate):
        if gate.shape[2:] != feat.shape[2:]:
            gate = F.interpolate(gate, feat.shape[2:], mode="trilinear", align_corners=False)
        # Additive attention -> per-voxel weight in [0, 1] applied to the skip features
        return feat * self.ps(self.r(self.Wf(feat) + self.Wg(gate)))
class EncBlock(nn.Module):
    """Two residual blocks + strided downsample. Returns (skip, downsampled)."""
    def __init__(self, ic, oc, drop=0.1):
        super().__init__()
        self.blk = nn.Sequential(ResBlock3D(ic, oc, drop=drop), ResBlock3D(oc, oc, drop=drop))
        self.down = nn.Sequential(nn.Conv3d(oc, oc, 3, stride=2, padding=1, bias=False),
                                  nn.BatchNorm3d(oc), nn.ReLU(True))
    def forward(self, x):
        s = self.blk(x)
        return s, self.down(s)
class DecBlock(nn.Module):
    """Upsample + attention-gated skip fusion + two residual blocks."""
    def __init__(self, ic, sc, oc, drop=0.1, ag=True):
        super().__init__()
        self.ag = AttGate3D(sc, ic) if ag else None
        self.blk = nn.Sequential(ResBlock3D(ic + sc, oc, drop=drop), ResBlock3D(oc, oc, drop=drop))
    def forward(self, x, skip):
        x = F.interpolate(x, skip.shape[2:], mode="trilinear", align_corners=False)
        if self.ag is not None:
            skip = self.ag(skip, x)
        return self.blk(torch.cat([x, skip], 1))
class Vbai26MSS(nn.Module):
    """
    Vbai-2.6MSS - single-task 3D UNet for MS lesion segmentation.
    Input: FLAIR (in_ch=1) -> output: 1-channel lesion logit.
    """
    def __init__(self, in_ch=1, bc=32, mults=(1, 2, 4, 8, 10), drop=0.1, ds=True):
        super().__init__()
        ch = [bc * m for m in mults]  # [32, 64, 128, 256, 320] with the defaults
        self.ds = ds
        self.stem = nn.Sequential(nn.Conv3d(in_ch, ch[0], 3, 1, 1, bias=False),
                                  nn.BatchNorm3d(ch[0]), nn.ReLU(True))
        self.e0 = EncBlock(ch[0], ch[0], drop=drop)
        self.e1 = EncBlock(ch[0], ch[1], drop=drop)
        self.e2 = EncBlock(ch[1], ch[2], drop=drop)
        self.e3 = EncBlock(ch[2], ch[3], drop=drop)
        self.bn = nn.Sequential(ResBlock3D(ch[3], ch[4], drop=drop), ASPP3D(ch[4], ch[4]))
        self.d0 = DecBlock(ch[4], ch[3], ch[3], drop=drop)
        self.d1 = DecBlock(ch[3], ch[2], ch[2], drop=drop)
        self.d2 = DecBlock(ch[2], ch[1], ch[1], drop=drop)
        self.d3 = DecBlock(ch[1], ch[0], ch[0], drop=drop)
        self.head = nn.Conv3d(ch[0], 1, 1)
        if ds:
            self.ds0 = nn.Conv3d(ch[3], 1, 1)
            self.ds1 = nn.Conv3d(ch[2], 1, 1)
            self.ds2 = nn.Conv3d(ch[1], 1, 1)
    def forward(self, x, return_aux=False):
        s = self.stem(x)
        k0, d0 = self.e0(s)
        k1, d1 = self.e1(d0)
        k2, d2 = self.e2(d1)
        k3, d3 = self.e3(d2)
        bn = self.bn(d3)
        u3 = self.d0(bn, k3)
        u2 = self.d1(u3, k2)
        u1 = self.d2(u2, k1)
        u0 = self.d3(u1, k0)
        out = self.head(u0)
        if return_aux and self.ds:
            # Deep-supervision logits at 1/8, 1/4 and 1/2 resolution (training only)
            return out, [self.ds0(u3), self.ds1(u2), self.ds2(u1)]
        return out  # inference: main head only
# ============================================================================
# PREPROCESS + SLIDING-WINDOW INFERENCE
# ============================================================================
def load_nii(path):
    """Load a NIfTI volume as float32 (handles a .nii path that is actually a folder)."""
    if os.path.isdir(path):
        inner = sorted(f for f in os.listdir(path) if f.lower().endswith((".nii", ".nii.gz")))
        if not inner:
            raise FileNotFoundError(f"No .nii/.nii.gz file found in folder: {path}")
        path = os.path.join(path, inner[0])
    v = np.asarray(nib.load(path).dataobj, dtype=np.float32)
    return np.nan_to_num(v, nan=0., posinf=0., neginf=0.)
def znorm(v):
    """Z-score normalization over the foreground (v > 0)."""
    m = v > 0
    if not m.any():
        return v.astype(np.float32)
    out = np.zeros_like(v, dtype=np.float32)
    out[m] = (v[m] - v[m].mean()) / max(v[m].std(), 1e-6)
    return out
def resamp(v, target, order=1):
    """Resample v to the target shape (order=1: trilinear interpolation)."""
    return zoom(v, [t / c for t, c in zip(target, v.shape)], order=order).astype(np.float32)
@torch.no_grad()
def sliding_window_prob(model, x, patch=PATCH, overlap=OVERLAP):
    """Run the model over overlapping patches and blend with a Gaussian window."""
    _, _, D, H, W = x.shape
    pd, ph, pw = patch
    # Zero-pad any axis smaller than the patch so every window has full size
    # (F.pad lists pads for the last axis first: W, then H, then D).
    x = F.pad(x, (0, max(0, pw - W), 0, max(0, ph - H), 0, max(0, pd - D)))
    PD, PH, PW = x.shape[2:]
    sd = max(1, int(pd * (1 - overlap)))
    sh = max(1, int(ph * (1 - overlap)))
    sw = max(1, int(pw * (1 - overlap)))
    def starts(dim, p, s):
        if dim <= p:
            return [0]
        st = list(range(0, dim - p + 1, s))
        if st[-1] != dim - p:
            st.append(dim - p)  # make the last window touch the border
        return st
    def gauss1d(n):
        c = (n - 1) / 2.0
        s = n * 0.125 + 1e-6  # sigma = patch / 8: patch centres get the most weight
        return np.exp(-0.5 * ((np.arange(n) - c) / s) ** 2)
    win = torch.tensor((gauss1d(pd)[:, None, None] * gauss1d(ph)[None, :, None] *
                        gauss1d(pw)[None, None, :]).astype(np.float32),
                       device=x.device).clamp_min(1e-4)
    acc = torch.zeros((PD, PH, PW), device=x.device)
    cnt = torch.zeros((PD, PH, PW), device=x.device)
    for z0 in starts(PD, pd, sd):
        for y0 in starts(PH, ph, sh):
            for x0 in starts(PW, pw, sw):
                patch_in = x[:, :, z0:z0 + pd, y0:y0 + ph, x0:x0 + pw]
                p = torch.sigmoid(model(patch_in))[0, 0]
                acc[z0:z0 + pd, y0:y0 + ph, x0:x0 + pw] += p * win
                cnt[z0:z0 + pd, y0:y0 + ph, x0:x0 + pw] += win
    prob = acc / cnt.clamp_min(1e-6)
    return prob[:D, :H, :W].cpu().numpy()  # crop away any padding
def clean_small(mask, min_size):
    """Remove connected components smaller than min_size voxels."""
    if min_size <= 1:
        return mask
    cc, n = cc_label(mask)
    if n == 0:
        return mask
    sizes = np.bincount(cc.ravel())
    sizes[0] = 0
    return np.isin(cc, np.where(sizes >= min_size)[0]).astype(np.uint8)
# ============================================================================
# MAIN: mark a brain file and visualize
# ============================================================================
@torch.no_grad()
def mark_brain(flair_path, ckpt_path, out_png, grid, n_slices, device):
    # Build model and load weights. strict=False tolerates checkpoints saved
    # without the deep-supervision (ds*) heads, which inference does not use.
    model = Vbai26MSS(in_ch=1, ds=True).to(device)
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    ck = torch.load(ckpt_path, map_location=device, weights_only=False)
    state = ck["model"] if isinstance(ck, dict) and "model" in ck else ck
    res = model.load_state_dict(state, strict=False)
    missing = [k for k in res.missing_keys if not k.startswith("ds")]
    if missing or res.unexpected_keys:
        # Anything beyond the ds* heads usually means a mismatched checkpoint
        print(f"Warning: missing keys {missing}, unexpected keys {res.unexpected_keys}")
    model.eval()
    print(f"Vbai-2.6MSS loaded: {os.path.basename(ckpt_path)}")
    # Preprocess: foreground z-score + resample to canvas grid
    raw = load_nii(flair_path)
    flair = resamp(znorm(raw), grid)
    disp = resamp(raw, grid)
    if (disp > 0).any():
        lo, hi = np.percentile(disp[disp > 0], [1, 99])
        disp = np.clip((disp - lo) / (hi - lo + 1e-6), 0, 1)
    # Inference -> probability -> binary mask -> noise cleanup
    x = torch.tensor(flair[None, None], dtype=torch.float32).to(device)
    prob = sliding_window_prob(model, x, PATCH, OVERLAP)
    pred = clean_small((prob >= THRESHOLD).astype(np.uint8), MIN_CC)
    n_vox = int(pred.sum())
    n_les = int(cc_label(pred)[1])
    print(f"Marked lesions: {n_les} (total {n_vox} voxels)")
    # Pick the axial slices with the most lesion voxels; overlay prediction in red
    n_slices = max(1, n_slices)  # argsort(...)[-0:] would select every slice
    scores = pred.sum(axis=(0, 1))
    zs = sorted(np.argsort(scores)[-n_slices:]) if scores.sum() > 0 else [grid[2] // 2]
    cols = len(zs)
    fig, axes = plt.subplots(1, cols, figsize=(3 * cols, 3.4))
    if cols == 1:
        axes = [axes]
    for ax, z in zip(axes, zs):
        ax.imshow(disp[:, :, z].T, cmap="gray", origin="lower")
        overlay = np.zeros((*pred[:, :, z].T.shape, 4), np.float32)
        overlay[..., 0] = 1.0                          # red channel
        overlay[..., 3] = (pred[:, :, z].T > 0) * 0.5  # 50% alpha on lesion voxels
        ax.imshow(overlay, origin="lower")
        ax.set_title(f"z={z}", fontsize=8)
        ax.axis("off")
    fig.suptitle(f"Vbai-2.6MSS | {os.path.basename(flair_path)} | "
                 f"{n_les} lesion(s), {n_vox} voxels (red = prediction)", fontsize=11)
    plt.tight_layout(rect=[0, 0, 1, 0.94])
    plt.savefig(out_png, dpi=130, bbox_inches="tight")
    print(f"Saved -> {out_png}")
    if _IN_COLAB:
        plt.show()
    plt.close(fig)
def main():
    ap = argparse.ArgumentParser(description="Vbai-2.6MSS - mark MS lesions on a FLAIR brain scan")
    ap.add_argument("--flair", required=True, help="path to a FLAIR .nii / .nii.gz file")
    ap.add_argument("--ckpt", default="Vbai-2.6MSS.pt", help="path to model weights")
    ap.add_argument("--out", default=None, help="output PNG path (default: next to input)")
    ap.add_argument("--grid", type=int, default=GRID[0], help="canvas size (default 128)")
    ap.add_argument("--n-slices", type=int, default=6, help="number of slices to display")
    args = ap.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(args.ckpt):
        print(f"Checkpoint not found: {args.ckpt}")
        return
    if not os.path.exists(args.flair):
        print(f"FLAIR not found: {args.flair}")
        return
    out = args.out or (os.path.splitext(args.flair.replace(".nii.gz", ".nii"))[0] + "_Vbai-2.6MSS.png")
    grid = (args.grid, args.grid, args.grid)
    mark_brain(args.flair, args.ckpt, out, grid, args.n_slices, device)
if __name__ == "__main__":
    main()
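The model card lists optional TTA, but the script above does not implement it and the card does not name the augmentations used. A minimal sketch, assuming flip-based TTA; `flip_tta` is a hypothetical helper that works on NumPy volumes so it can wrap any predictor, including the script's `sliding_window_prob`.

```python
import itertools

import numpy as np


def flip_tta(predict, volume, axes=(0, 1, 2)):
    """Average `predict` over every combination of flips along `axes`.

    `predict` maps a (D, H, W) array to a (D, H, W) probability map. With
    three axes this runs 2**3 = 8 forward passes.
    """
    volume = np.asarray(volume)
    combos = [c for r in range(len(axes) + 1) for c in itertools.combinations(axes, r)]
    acc = np.zeros(volume.shape, dtype=np.float64)
    for flip_axes in combos:
        v = np.flip(volume, flip_axes) if flip_axes else volume
        # Contiguous copy: torch.tensor() rejects arrays with negative strides.
        p = np.asarray(predict(np.ascontiguousarray(v)))
        # Flipping is its own inverse, so flip the output back to align it.
        acc += np.flip(p, flip_axes) if flip_axes else p
    return (acc / len(combos)).astype(np.float32)
```

To try it with the script, one option is to replace the `prob = sliding_window_prob(...)` line in `mark_brain` with `prob = flip_tta(lambda v: sliding_window_prob(model, torch.tensor(v[None, None], device=device), PATCH, OVERLAP), flair)`. This costs roughly eight times the inference time, and whether flips along all three axes help depends on the augmentations used during training.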
Requirements
- Python ≥ 3.9
- PyTorch ≥ 2.0
- CUDA-capable GPU with ≥ 16 GB VRAM recommended (tested on an NVIDIA Tesla T4 with 16 GB VRAM; trained on an NVIDIA L4 with 24 GB VRAM)
- See requirements.txt for the full dependency list
License
CC BY-NC-SA 4.0. See the LICENSE file for details.
Support
- Website: Neurazum
- Email: contact@neurazum.com