LDW-CNet — Learnable Discrete Wavelet CNN for Brain Tumor MRI Classification

LDW-CNet is a convolutional neural network that combines a learnable Discrete Wavelet Transform (DWT) front-end with a residual Squeeze-and-Excitation (SE) backbone for 4-class brain MRI classification (glioma, meningioma, pituitary, no tumor). It is trained on the [masoudnickparvar brain-tumor-mri-dataset](https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset) and is intended as a research artefact demonstrating the value of integrating classical signal-processing priors into deep networks.

Model description

The model has three stages:

  1. Learnable DWT layer. Decomposes the input into LL/LH/HL/HH subbands using depthwise separable convolutions whose filters are initialised from the Daubechies-4 (db4) wavelet and then trained. A wavelet-constraint loss penalises deviations from g = QMF(h), where the high-pass filter g is the quadrature mirror of the low-pass filter h, and from ‖h‖² = ‖g‖² = 1, so the filters stay close to a valid orthogonal wavelet pair during training.
  2. SE-recalibration of the wavelet subbands so the network can re-weight subband importance per input.
  3. Residual CNN backbone with 4 stages, channel attention (SE), and linearly increasing stochastic-depth regularization (0 → 0.2). The final classifier is a 2-layer LayerNorm + SiLU MLP.
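
The constraint in stage 1 can be made concrete with a small sketch. The exact form of the loss is not given in this card, so the squared-error penalty, the function names `qmf` and `wavelet_constraint_loss`, and the alternating-flip sign convention below are assumptions for illustration only. The coefficients are the standard 8-tap db4 decomposition low-pass filter.

```python
# Standard db4 decomposition low-pass filter (8 taps).
DB4_LO = [
    -0.010597401785069032, 0.0328830116668852,
    0.030841381835560764, -0.18703481171909309,
    -0.027983769416859854, 0.6308807679298589,
    0.7148465705529157, 0.2303778133088965,
]


def qmf(h):
    """High-pass filter from low-pass h via g[n] = (-1)^n * h[L-1-n].

    Assumption: libraries differ in the overall sign of this rule.
    """
    length = len(h)
    return [(-1) ** n * h[length - 1 - n] for n in range(length)]


def wavelet_constraint_loss(h, g):
    """Hypothetical penalty: QMF mismatch plus unit-norm deviation of h and g."""
    target = qmf(h)
    qmf_term = sum((gi - ti) ** 2 for gi, ti in zip(g, target))
    norm_h = (sum(x * x for x in h) - 1.0) ** 2
    norm_g = (sum(x * x for x in g) - 1.0) ** 2
    return qmf_term + norm_h + norm_g
```

A filter pair initialised from db4 starts with a loss that is numerically zero, so under this reading the penalty only becomes active once training moves the filters away from the wavelet.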

Variants in this repo

| Model | Description |
| --- | --- |
| PlainCNN | Backbone alone, no wavelet front-end |
| FixedDWT_CNN | Backbone with fixed db4 wavelet |
| LearnDWT_NoLwc | Learnable DWT, no constraint loss, no SE on subbands |
| LearnDWT_SE_NoLwc | Learnable DWT + SE, no constraint loss |
| LDWCNet_Full | Full model: learnable DWT + SE + wavelet-constraint loss + stochastic depth |

Intended use

  • Primary: a research demonstration of wavelet–CNN hybrids for medical imaging.
  • Out of scope: clinical decision making. This model is not a medical device and has not been validated for diagnostic use. It must not be used to inform real patient care.

How to load

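import torch
from huggingface_hub import hf_hub_download

# Requires PyWavelets: pip install pywavelets
# LDWCNet is not packaged. Copy the class from this repo's notebooks first,
# for example into a local module, then import it:
# from notebook_02_architecture import LDWCNet

# Download the checkpoint from the Hub (cached locally after the first call).
ckpt_path = hf_hub_download(
    'Shanmuk4622/ldwcnet-brain-mri-5-model-model',
    'checkpoints/LDWCNet_Full_best.pt',
    repo_type='model')

# weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints from a source you trust.
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)

# Build the architecture, then restore the trained weights.
model = LDWCNet(in_channels=3, num_classes=4, wavelet_init='db4',
                stochastic_depth_max=0.2)
model.load_state_dict(ckpt['state_dict'])
model.eval()  # switch to inference mode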
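
The headline metrics below use 8-way test-time augmentation (TTA). This card does not say which 8 views are used. A common choice is the four 90° rotations, each with and without a horizontal flip, and that is what the sketch below assumes. It works on a plain 2-D list so it runs without a model; `predict_fn` is a hypothetical stand-in for any function that maps an image to class probabilities. With tensors, the same views can be produced with `torch.rot90` and `torch.flip`.

```python
def dihedral_views(img):
    """Return 8 views of a 2-D image (list of rows): 4 rotations x optional h-flip."""
    views = []
    cur = [list(row) for row in img]
    for _ in range(4):
        views.append(cur)
        views.append([row[::-1] for row in cur])      # horizontal flip
        cur = [list(row) for row in zip(*cur[::-1])]  # rotate 90 degrees clockwise
    return views


def tta_predict(predict_fn, img):
    """Average the class probabilities returned by predict_fn over all 8 views."""
    probs = [predict_fn(view) for view in dihedral_views(img)]
    n_views = len(probs)
    n_classes = len(probs[0])
    return [sum(p[k] for p in probs) / n_views for k in range(n_classes)]
```

Averaging probabilities rather than logits is itself a design choice; either is plausible here and the card does not state which one was used.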

Performance — held-out test split

Headline metrics (Test-Time Augmentation, 8-way)

| Metric | Value |
| --- | --- |
| Accuracy | 0.8950 |
| Balanced Accuracy | 0.8950 |
| Macro F1 | 0.8925 |
| Weighted F1 | 0.8925 |
| MCC | 0.8627 |
| Cohen's κ | 0.8600 |
| Macro ROC-AUC | 0.9749 |
| Macro AP | 0.9506 |

Bootstrap 95% CI on Macro-F1 (n=2000 resamples): [0.8761, 0.9071]
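
For readers who want to reproduce an interval of this kind, here is a minimal sketch of a percentile bootstrap on Macro-F1. The card states only the number of resamples, so the percentile method, the seed, and the function names `macro_f1` and `bootstrap_ci` are assumptions.

```python
import random


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over classes present in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


def bootstrap_ci(y_true, y_pred, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile CI: resample test items with replacement and recompute the metric."""
    rng = random.Random(seed)
    n = len(y_true)
    stats = []
    for _ in range(n_resamples):
        idx = [rng.randrange(n) for _ in range(n)]
        stats.append(macro_f1([y_true[i] for i in idx],
                              [y_pred[i] for i in idx]))
    stats.sort()
    lo = stats[int((alpha / 2) * n_resamples)]
    hi = stats[min(n_resamples - 1, int((1 - alpha / 2) * n_resamples))]
    return lo, hi
```

Calling `bootstrap_ci(y_true, y_pred)` with the test labels and the model's predicted labels returns a `(lower, upper)` pair comparable to the interval quoted above.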

Per-class breakdown — LDWCNet_Full

| Class | Precision | Recall | F1 | Support |
| --- | --- | --- | --- | --- |
| glioma | 0.9589 | 0.7575 | 0.8464 | 400 |
| meningioma | 0.8760 | 0.8475 | 0.8615 | 400 |
| notumor | 0.8955 | 0.9850 | 0.9381 | 400 |
| pituitary | 0.8665 | 0.9900 | 0.9242 | 400 |
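
Because every class has the same support (400), the headline metrics follow directly from this table: accuracy equals the mean recall, Macro F1 equals Weighted F1, and the chance agreement in Cohen's κ is 1/4 regardless of the predictions. The short check below reproduces those numbers from the rounded table values; the variable names are illustrative only.

```python
# (precision, recall, F1) per class, copied from the per-class table.
per_class = {
    "glioma": (0.9589, 0.7575, 0.8464),
    "meningioma": (0.8760, 0.8475, 0.8615),
    "notumor": (0.8955, 0.9850, 0.9381),
    "pituitary": (0.8665, 0.9900, 0.9242),
}


def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


recalls = [r for _, r, _ in per_class.values()]
f1s = [f for _, _, f in per_class.values()]

# Equal supports: accuracy equals balanced accuracy, the mean per-class recall.
accuracy = sum(recalls) / len(recalls)
mean_f1 = sum(f1s) / len(f1s)

# Equal true class frequencies give a chance agreement of exactly 1/4.
kappa = (accuracy - 0.25) / (1 - 0.25)
```

The recomputed values agree with the headline table up to rounding of the inputs.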

All variants

| Model | Accuracy | Macro F1 | MCC | Macro ROC-AUC |
| --- | --- | --- | --- | --- |
| LDWCNet_Full | 0.8950 | 0.8925 | 0.8627 | 0.9749 |
| FixedDWT_CNN | 0.8625 | 0.8589 | 0.8204 | 0.9605 |
| LearnDWT_SE_NoLwc | 0.8500 | 0.8460 | 0.8032 | 0.9559 |
| LearnDWT_NoLwc | 0.8469 | 0.8431 | 0.7996 | 0.9521 |
| PlainCNN | 0.8444 | 0.8388 | 0.7971 | 0.9569 |

Efficiency

| Metric | Value |
| --- | --- |
| Parameters | 5.08M |
| GFLOPs | 3.19 |
| Latency (batch=1) | 2.89 ± 0.10 ms |
| Throughput (batch=32) | 528 img/s |

Measured on the training device (see reports/efficiency.csv).

Training procedure

  • Optimizer: AdamW, base LR 3e-4, weight decay 1e-4 for non-DWT params (0 for DWT)
  • Schedule: OneCycleLR, cosine anneal, 10% warm-up
  • Loss: class-weighted cross-entropy with label smoothing (smoothing=0.05) + cosine-scheduled wavelet-constraint loss (λ_wc_max=0.5, 10-epoch warm-up, then decay to 0)
  • Regularisation: Mixup + CutMix at p=0.5 (50/50 split), stochastic depth 0→0.2, EMA (decay 0.9998 with timm-style warm-up ramp), TTA at evaluation
  • Augmentation: hflip, vflip, rotate ±20°, RandomBrightnessContrast, ShiftScaleRotate, ElasticTransform, GridDistortion, CLAHE, GaussNoise, CoarseDropout
  • Warm-start for LDWCNet_Full: backbone + SE + classifier initialised from FixedDWT_CNN_best.pt; DWT filters frozen for the first 30 epochs, then unfrozen with λ_wc=0.5 to keep them as valid wavelets
  • Epochs: 120 | Batch size: 32 per GPU | Image size: 224×224
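
The wavelet-constraint weight schedule in the Loss bullet can be read in more than one way. The sketch below shows one reading: a linear ramp to λ_wc_max over the first 10 epochs, followed by a half-cosine decay that reaches 0 at the final epoch. The linear ramp, the zero-based epoch index, and the function name `lambda_wc` are assumptions, and the sketch does not model the 30-epoch filter freeze used in the LDWCNet_Full warm-start.

```python
import math


def lambda_wc(epoch, total_epochs=120, warmup_epochs=10, lam_max=0.5):
    """Constraint-loss weight for a zero-based epoch index.

    Assumed shape: linear warm-up to lam_max, then cosine decay to 0.
    """
    if epoch < warmup_epochs:
        return lam_max * (epoch + 1) / warmup_epochs
    span = max(1, total_epochs - 1 - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / span)
    return 0.5 * lam_max * (1.0 + math.cos(math.pi * progress))


# Weight for every epoch of a 120-epoch run.
schedule = [lambda_wc(e) for e in range(120)]
```

The per-epoch total loss under this reading would be the classification loss plus `lambda_wc(epoch)` times the constraint term.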

Limitations and biases

  • Trained on a single public dataset; performance on data from other scanners, populations, or protocols is unknown and likely worse.
  • Patient demographics in the dataset are not balanced and not fully documented; the model's behaviour across demographic subgroups has not been audited.
  • The model assumes a clean axial T1/T2 brain MRI slice as input. Behaviour on non-brain or non-MRI inputs is undefined.
  • Calibration: the expected calibration error (ECE) before temperature scaling is 0.0665; consider applying temperature scaling with T=0.96 before treating the output probabilities as confidence scores.
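
As a minimal sketch of the calibration step, assuming the usual form of temperature scaling in which logits are divided by T before the softmax (the card does not spell this out), the function name `softmax_with_temperature` is hypothetical:

```python
import math


def softmax_with_temperature(logits, T=0.96):
    """Softmax of logits / T. T < 1 sharpens the distribution, T > 1 flattens it."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

Temperature scaling leaves the predicted class unchanged, so accuracy and F1 are unaffected; only the confidence values move.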

Citation

@misc{ldwcnet,
  title  = {LDW-CNet: Learnable Discrete Wavelet CNN for Brain Tumor MRI Classification},
  author = {LDW-CNet authors},
  year   = {2026},
  howpublished = {Hugging Face},
  url    = {https://huggingface.co/Shanmuk4622/ldwcnet-brain-mri-5-model-model}
}

License

MIT
