---
license: apache-2.0
language:
- en
tags:
- remote-sensing
- semantic-segmentation
- agriculture
- crop-mapping
- sentinel-2
- swin-transformer
- stcln
- pastis
- pytorch
- amd-rocm
datasets:
- pastis
metrics:
- mean_iou
- f1
- accuracy
model-index:
- name: Swin-STCLN-PASTIS
  results:
  - task:
      type: image-segmentation
      name: Semantic Segmentation
    dataset:
      name: PASTIS
      type: pastis
    metrics:
    - type: mean_iou
      value: 44.43
      name: Mean IoU (5-fold CV)
    - type: f1
      value: 58.35
      name: F1 Score (5-fold CV)
    - type: accuracy
      value: 68.36
      name: Overall Accuracy (5-fold CV)
---

# Swin-STCLN × PASTIS

> **Swin-STCLN: Hierarchical Swin Transformer Enhanced Spatio-Temporal Contrastive Learning Network for Crop Mapping**

This repository implements an improved STCLN architecture for semantic segmentation of Sentinel-2 satellite image time series on the PASTIS benchmark. It preserves the original STCLN pretrain → finetune workflow, replaces the flat CNN spatial encoder with a hierarchical Swin Transformer encoder, and adds cross-scale spatiotemporal fusion with boundary-aware refinement.

---

# 🏆 Results: 5-Fold Cross Validation

## Summary

| Metric | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Fold 5 | **Mean ± Std** |
|--------|--------|--------|--------|--------|--------|----------------|
| **mFscore** | 58.37% | 56.69% | 59.18% | 61.00% | 56.49% | **58.35% ± 1.70%** |
| **mIoU** | 44.35% | 42.90% | 45.54% | 46.80% | 42.58% | **44.43% ± 1.62%** |
| **OA** | 67.10% | 66.85% | 69.96% | 71.48% | 66.41% | **68.36% ± 2.13%** |
| **Kappa** | 60.02% | 59.71% | 62.79% | 64.33% | 58.99% | **61.17% ± 2.19%** |
| **mPrecision** | 52.78% | 51.33% | 54.51% | 57.40% | 51.24% | **53.45% ± 2.47%** |
| **mRecall** | 70.72% | 69.00% | 68.69% | 68.93% | 68.72% | **69.21% ± 0.82%** |

> Official benchmark result:
>
> **mFscore = 58.35% ± 1.70%**

Trained from scratch on an AMD MI300X GPU.
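As a quick sanity check, the means in the summary table can be recomputed from the per-fold values. The snippet below is only an illustrative sketch: the names `fold_scores` and `fold_means` are hypothetical and not part of this repository, and the spread is left out because the card does not state which standard-deviation convention it uses.

```python
from statistics import mean

# Per-fold scores (percent) copied from the summary table above.
fold_scores = {
    "mFscore": [58.37, 56.69, 59.18, 61.00, 56.49],
    "mIoU": [44.35, 42.90, 45.54, 46.80, 42.58],
    "OA": [67.10, 66.85, 69.96, 71.48, 66.41],
}

# Average each metric over the five folds.
fold_means = {name: mean(values) for name, values in fold_scores.items()}

for name, value in fold_means.items():
    print(f"{name}: {value:.2f}%")
```

The same pattern extends to Kappa, mPrecision, and mRecall by adding their rows to the dictionary.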
---

# 📊 Per-Class IoU: All Folds

| Class | Mean IoU |
|---|---|
| 🌽 Corn | **75.90%** |
| 🌿 Winter rapeseed | **75.55%** |
| 🌾 Beet | **73.90%** |
| 🌾 Soft winter wheat | **71.47%** |
| 🌿 Soybeans | **61.48%** |
| 🌾 Winter barley | **58.64%** |
| 🌱 Meadow | **56.69%** |
| 🌻 Sunflower | **54.87%** |
| 🟫 Background | **50.98%** |
| 🌾 Winter durum wheat | **44.56%** |
| 🌿 Grapevine | **36.70%** |
| 🥔 Potatoes | **36.34%** |
| 🌾 Spring barley | **35.08%** |
| 🌿 Leguminous fodder | **22.41%** |
| 🌾 Winter triticale | **22.34%** |
| 🍎 Fruits/veg/flowers | **21.53%** |
| 🍑 Orchard | **17.27%** |
| 🌾 Mixed cereal | **14.66%** |
| 🌿 Sorghum | **13.91%** |

---

# 🏗️ Model Architecture: Swin-STCLN

## 1. Swin Spatial Encoder (SEncoder Replacement)

Original STCLN:

```
DoubleConv + DoubleConv
filters: 32 → 256
kernel: 3×3
Output: (B×T,256,32,32)
```

Problems with the original encoder:

- flat representation
- single scale
- limited long-range spatial modeling

Swin-STCLN replaces it with the stages below.

Input: (B,T,10,32,32)

### Patch Embedding

- Conv2D: 10 → 96 channels
- kernel = 2
- stride = 2
- Spatial: 32×32 → 16×16

Output: (B×T,96,16,16)

### Swin Stage 1

2 Swin Transformer Blocks:

- Window attention
- Shifted window attention

Configuration:

- dim = 96
- window = 4

Output: f_fine: (B×T,96,16,16)

### Patch Merging

- Spatial: 16×16 → 8×8
- Channels: 96 → 192

### Swin Stage 2

2 Swin Transformer Blocks

- dim = 192
- window = 4

Output: f_coarse: (B×T,192,8,8)

---

## 2. Temporal Encoder

The original STCLN Transformer Temporal Encoder is kept unchanged.

Configuration:

- Transformer layers: 3
- Attention heads: 8
- Sinusoidal positional encoding
- GroupNorm

Only the input representation changes:

- Original STCLN: (B×32×32) × T × 256
- Swin-STCLN: (B×8×8) × T × 192

Benefits:

- the number of spatial locations processed by the temporal encoder drops from 1024 (32×32) to 64 (8×8)
- temporal reasoning happens on semantic regions
- lower memory usage

---

## 3. STFusion: Cross-Scale Fusion

Replaces the STCLN STA module.

Inputs:

- Temporal branch: coarse_agg (B,192,8,8)
- Spatial branch: fine_agg (B,96,16,16)

Process:

1. Upsample 8×8 → 16×16
2. Projection 192 → 96
3. Cross attention
   - Query: coarse semantic features
   - Key / Value: fine spatial features
4. Residual fusion
5. Upsample 16×16 → 32×32

Output: (B,128,32,32)

---

## 4. Pretraining Reconstruction Decoder

The original STCLN self-supervised learning objective is preserved.

Unchanged:

- Spatiotemporal masking
- Mask ratio = 0.4
- MSE reconstruction loss
- Unlabeled pretraining patches = 7936

Difference: base STCLN reconstructs directly with a Linear layer because its temporal features remain at 32×32. Swin-STCLN reconstructs from hierarchical features, so its reconstruction head upsamples the 8×8 temporal features back to 32×32.

TEncoder output:

```
(B,T,192,8,8)
```

Reconstruction Head:

```
ConvTranspose2D   8×8 → 16×16
ConvTranspose2D   16×16 → 32×32
Conv2D            192 → 10 Sentinel-2 bands
```

Final reconstruction:

```
(B,T,10,32,32)
```

The same masked-pixel MSE objective is applied.

---

## 5. Finetuning Decoder

The original STCLN linear decoder is replaced with a boundary-aware segmentation decoder.

Input:

```
STFusion output (B,128,32,32)
```

### Semantic Decoder

Architecture:

Conv-BN-ReLU

```
128 → 64
```

Conv-BN-ReLU

```
64 → 64
```

1×1 Conv classifier

Output:

```
(B,18,32,32)
```

---

### Boundary Decoder

Parallel boundary prediction branch:

Conv-BN-ReLU

```
128 → 64
```

Conv-BN-ReLU

```
64 → 64
```

1×1 Conv

Output:

```
(B,1,32,32)
```

Boundary supervision is generated automatically by applying a morphological gradient to the semantic masks. No additional annotation is required.

---

### Gated Refinement

1. Concatenate the semantic and boundary features.
2. Refine the concatenated features with a convolution.
3. Output the final refined prediction:

```
(B,18,32,32)
```

---

## 6. Finetuning Loss

Original STCLN:

```
Cross Entropy
```

Swin-STCLN:

```
Total Loss =
    CE(semantic output)
  + CE(refined output)
  + 0.5 × BCE(boundary output)
```

The boundary loss improves separation between neighbouring crop parcels.

Boundary weight = 0.5

---

# ⚡ Inference Speed

Measured on an AMD MI300X GPU.

| Batch Size | Time (ms) | Throughput | VRAM Used |
|-----------|-----------|------------|-----------|
| 1 | 7.8 ms | 128.9 patches/sec | 1.12 GB |
| 4 | 19.6 ms | 203.8 patches/sec | 2.08 GB |
| 8 | 37.1 ms | 215.4 patches/sec | 3.40 GB |
| 16 | 73.4 ms | 218.0 patches/sec | 6.03 GB |
| 32 | 143.0 ms | 223.8 patches/sec | 11.31 GB |
| 64 | 281.0 ms | 227.8 patches/sec | 21.85 GB |

---

# 📁 Repository Structure

```
Swin-STCLN-PASTIS/
├── models/
│   ├── swin_encoder.py        # PatchEmbed + Swin Blocks + PatchMerging
│   ├── temporal_encoder.py    # STCLN Transformer Encoder
│   ├── stfusion.py            # Cross-scale attention fusion
│   ├── decoder.py             # Semantic decoder, boundary decoder, gated refinement
│   ├── reconstruction.py      # ConvTranspose reconstruction head
│   └── swin_stcln.py          # Complete architecture
├── datasets/
│   └── pastis_dataset.py
├── losses/
│   ├── segmentation_loss.py
│   └── boundary_loss.py
├── evaluation/
│   └── metrics.py
├── train.py
├── pretrain.py
├── finetune.py
├── visualize_results.py
├── checkpoints/
└── results/
```

---

# 🚀 Quick Start

## Installation

```bash
git clone https://huggingface.co/Dhruv1000/Swin-STCLN-PASTIS
cd Swin-STCLN-PASTIS
pip install torch torchvision timm einops geopandas matplotlib scikit-learn
```

---

## Training

Single fold:

```bash
python train.py \
  --data_root /path/to/PASTIS \
  --fold 1 \
  --epochs 100 \
  --batch_size 16 \
  --lr 5e-5 \
  --weight_decay 0.05 \
  --warmup_iters 500 \
  --num_workers 4 \
  --amp \
  --work_dir ./work_dirs/fold1
```

---

## 5-Fold Cross Validation

```bash
# Train one model per fold, each in its own work directory.
for fold in 1 2 3 4 5
do
  python train.py \
    --data_root /path/to/PASTIS \
    --fold "$fold" \
    --epochs 100 \
    --batch_size 16 \
    --lr 5e-5 \
    --work_dir "./work_dirs/fold${fold}"
done
```

---

## Inference

```python
import torch

from models.swin_stcln import build_swin_stcln

# Build the model and load the finetuned weights.
model = build_swin_stcln(num_classes=18)
checkpoint = torch.load(
    "checkpoints/best_model.pth",
    map_location="cpu",   # load on CPU first; move the model to a GPU afterwards if needed
    weights_only=False,   # full unpickling: only load checkpoints you trust
)
model.load_state_dict(checkpoint["model"])
model.eval()

# Input layout: (batch, time, Sentinel-2 bands, height, width)
x = torch.randn(1, 32, 10, 32, 32)

# Disable gradient tracking for inference.
with torch.no_grad():
    logits = model(x)  # (1, 18, 32, 32)

# Per-pixel class index, shape (1, 32, 32)
prediction = logits.argmax(dim=1)
```

---

# 📋 Training Configuration

| Parameter | Value |
|---|---|
| Model | Swin-STCLN |
| Spatial Encoder | Swin Transformer |
| Temporal Encoder | STCLN Transformer Encoder |
| Fusion | Cross-scale STFusion |
| Optimizer | AdamW β=(0.9,0.999) |
| Learning rate | 5e-5 |
| Weight decay | 0.05 |
| Schedule | Warmup 500 iters + cosine decay |
| Batch size | 16 |
| Epochs | 100 |
| AMP | Enabled |
| Gradient clipping | max_norm=5.0 |
| Loss | CE + CE + 0.5 BCE |
| Input bands | Sentinel-2 10 bands |
| Input size | 32×32 |
| Classes | 18 |

---

# 📦 Dataset: PASTIS

The model is evaluated on the PASTIS (Panoptic Agricultural Satellite Time Series) benchmark.

| Property | Details |
|---|---|
| Total patches | 2,433 geo-referenced tiles |
| Satellite | Sentinel-2 |
| Spectral bands | 10 |
| Temporal observations | 61 |
| Input crop | 32×32 pixels |
| Classes | 18 crop classes |
| Splits | Official 5-fold geographic split |
| Pretraining data | 7936 unlabeled patches |
| Label fraction | 2% labelled setting |

The official train / validation / test split is preserved.
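The 32×32 input crop listed above has to be cut from the larger PASTIS tiles. The sketch below shows one way this could be done, assuming tiles of shape `(T, 10, 128, 128)` (the tile size comes from the PASTIS dataset description, not from this card) and a non-overlapping grid. The function name `split_into_crops` is hypothetical, and the repository's `pastis_dataset.py` may crop differently.

```python
import numpy as np


def split_into_crops(tile: np.ndarray, crop: int = 32) -> np.ndarray:
    """Split a (T, C, H, W) tile into non-overlapping (N, T, C, crop, crop) crops."""
    t, c, h, w = tile.shape
    if h % crop or w % crop:
        raise ValueError("tile height and width must be multiples of the crop size")
    rows, cols = h // crop, w // crop
    # Split H into (rows, crop) and W into (cols, crop), then bring the
    # grid axes to the front: (rows, cols, T, C, crop, crop).
    grid = tile.reshape(t, c, rows, crop, cols, crop).transpose(2, 4, 0, 1, 3, 5)
    # Flatten the grid in row-major order into a single crop axis.
    return grid.reshape(rows * cols, t, c, crop, crop)


# Synthetic stand-in for one tile: 61 dates, 10 bands, 128×128 pixels (assumed size).
rng = np.random.default_rng(0)
tile = rng.standard_normal((61, 10, 128, 128)).astype(np.float32)
crops = split_into_crops(tile)
print(crops.shape)  # (16, 61, 10, 32, 32)
```

Each crop keeps the full time series and all bands, so it can be batched directly into the `(B, T, 10, 32, 32)` layout the model expects.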
---

# 🌾 PASTIS Classes

| ID | Class | Avg IoU |
|---|---|---|
| 0 | Background | 50.98% |
| 1 | Meadow | 56.69% |
| 2 | Soft winter wheat | 71.47% |
| 3 | Corn | 75.90% |
| 4 | Winter barley | 58.64% |
| 5 | Winter rapeseed | 75.55% |
| 6 | Spring barley | 35.08% |
| 7 | Sunflower | 54.87% |
| 8 | Grapevine | 36.70% |
| 9 | Beet | 73.90% |
| 10 | Winter triticale | 22.34% |
| 11 | Winter durum wheat | 44.56% |
| 12 | Fruits/veg/flowers | 21.53% |
| 13 | Potatoes | 36.34% |
| 14 | Leguminous fodder | 22.41% |
| 15 | Soybeans | 61.48% |
| 16 | Orchard | 17.27% |
| 17 | Mixed cereal | 14.66% |
| 18 | Sorghum | 13.91% |

---

# 🔥 What Remains Identical To Original STCLN

The following components are unchanged:

✅ Pretrain → finetune workflow

✅ Spatiotemporal masked reconstruction

✅ Mask ratio:

```
0.4
```

✅ Reconstruction objective:

```
Mean Squared Error
```

✅ Temporal Encoder design:

- 3 Transformer layers
- 8 attention heads
- sinusoidal positional encoding
- GroupNorm

✅ Dataset protocol:

- PASTIS32
- Sentinel-2
- 10 spectral channels
- Official folds

✅ Weight transfer:

```
Pretrained:
SEncoder + TEncoder
↓
Finetuning initialization
```

---

# 🆚 STCLN vs Swin-STCLN

| Component | STCLN | Swin-STCLN |
|---|---|---|
| Spatial Encoder | DoubleConv CNN | Swin Transformer |
| Spatial hierarchy | Single scale | Multi scale |
| Feature output | 256 @ 32×32 | 96 @ 16×16 + 192 @ 8×8 |
| Spatial attention | Local CNN | Window self-attention |
| Temporal input | 1024 spatial tokens | 64 semantic tokens |
| TEncoder | Transformer | Same Transformer |
| Fusion | STA | Cross-scale STFusion |
| Reconstruction | Linear | ConvTranspose decoder |
| Decoder | Linear classifier | Semantic + Boundary Decoder |
| Boundary learning | No | Yes |
| Final refinement | No | Gated refinement |
| Loss | CE | CE + CE + Boundary BCE |

---

# 📈 Training Dynamics (Fold 1)

| Epoch | Train Loss | Val Loss | mFscore | mIoU | Kappa |
|---|---|---|---|---|---|
| 1 | 0.878 | 0.671 | 2.79% | 1.47% | 2.39% |
| 4 | 0.431 | 0.472 | 20.08% | 12.34% | 10.76% |
| 10 | 0.320 | 0.380 | ~33% | ~22% | ~24% |
| 18 | 0.222 | 0.323 | 35.55% | 24.17% | 24.93% |
| 55 | 0.083 | 0.363 | 53.37% | 39.92% | 53.16% |
| 92 | 0.050 | 0.350 | 58.20% | 44.20% | 60.0% |
| 100 | 0.048 | 0.360 | 57.90% | 44.10% | 59.8% |

Best checkpoint:

```
Epoch 92
```

Total training time per fold:

```
~32 minutes on AMD MI300X
```

---

# 🖼️ Visualizations

Generated evaluation plots:

## Per Fold

```
results/fold{N}/plots/
```

Includes:

- Training curves
- Per-class IoU plots
- Metrics radar chart
- Confusion matrix
- Prediction maps
- Error maps
- IoU scatter analysis
- Overfitting analysis

---

# 🔧 Implementation Details

## Pure PyTorch Implementation

No dependency on:

- MMSegmentation
- MMEngine
- any external training framework

Implemented with:

- PyTorch
- timm
- einops

---

# Major Architectural Changes

## 1. Hierarchical Spatial Learning

The CNN encoder is replaced by Swin Transformer blocks.

Benefits:

- larger receptive field
- window attention
- shifted window information exchange

---

## 2. Efficient Temporal Modelling

Instead of:

```
1024 temporal sequences/image
```

Swin-STCLN uses:

```
64 temporal sequences/image
```

This reduces memory while keeping semantic information.

---

## 3. Cross-Scale Feature Recovery

The coarse temporal representation loses boundaries. STFusion uses cross-attention fusion to combine:

- high-level semantics
- fine spatial details

---

## 4. Boundary-Aware Learning

The boundary decoder learns crop separation from automatically generated boundary masks. No manual boundary annotation is needed.

---

# 📚 Citation

If this implementation is useful, please cite the original STCLN work and the PASTIS benchmark.

```bibtex
@inproceedings{garnot2021pastis,
  title={Panoptic Segmentation of Satellite Image Time Series with Convolutional Temporal Attention Networks},
  author={Garnot, Vivien Sainte Fare and Landrieu, Loic},
  booktitle={ICCV},
  year={2021}
}
```

For Swin Transformer:

```bibtex
@inproceedings{liu2021swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and others},
  booktitle={ICCV},
  year={2021}
}
```

---

# 📄 License

Apache-2.0

---

Trained on AMD MI300X, ROCm 7.0, PyTorch 2.x

**Swin-STCLN × PASTIS**

Hierarchical Swin Spatial Encoder + STCLN Temporal Encoder + Cross-Scale Boundary-Aware Fusion