Swin-STCLN × PASTIS

Swin-STCLN: Hierarchical Swin Transformer Enhanced Spatio-Temporal Contrastive Learning Network for Crop Mapping

This repository implements an improved STCLN architecture for Sentinel-2 satellite image time-series semantic segmentation on the PASTIS benchmark.

The original STCLN pretrain → finetune workflow is preserved. The flat CNN spatial encoder is replaced with a hierarchical Swin Transformer encoder, and cross-scale spatiotemporal fusion with boundary-aware refinement is added.


🏆 Results: 5-Fold Cross Validation

Summary

| Metric | Fold 1 | Fold 2 | Fold 3 | Fold 4 | Fold 5 | Mean ± Std |
|---|---|---|---|---|---|---|
| mFscore | 58.37% | 56.69% | 59.18% | 61.00% | 56.49% | 58.35% ± 1.70% |
| mIoU | 44.35% | 42.90% | 45.54% | 46.80% | 42.58% | 44.43% ± 1.62% |
| OA | 67.10% | 66.85% | 69.96% | 71.48% | 66.41% | 68.36% ± 2.13% |
| Kappa | 60.02% | 59.71% | 62.79% | 64.33% | 58.99% | 61.17% ± 2.19% |
| mPrecision | 52.78% | 51.33% | 54.51% | 57.40% | 51.24% | 53.45% ± 2.47% |
| mRecall | 70.72% | 69.00% | 68.69% | 68.93% | 68.72% | 69.21% ± 0.82% |

Official benchmark result:

mFscore = 58.35% ± 1.70%

Trained from scratch on AMD MI300X GPU.


📊 Per-Class IoU: All Folds

| Class | Mean IoU |
|---|---|
| 🌽 Corn | 75.90% |
| 🌿 Winter rapeseed | 75.55% |
| 🌾 Beet | 73.90% |
| 🌾 Soft winter wheat | 71.47% |
| 🌿 Soybeans | 61.48% |
| 🌾 Winter barley | 58.64% |
| 🌱 Meadow | 56.69% |
| 🌻 Sunflower | 54.87% |
| 🟫 Background | 50.98% |
| 🌾 Winter durum wheat | 44.56% |
| 🌿 Grapevine | 36.70% |
| 🥔 Potatoes | 36.34% |
| 🌾 Spring barley | 35.08% |
| 🌿 Leguminous fodder | 22.41% |
| 🌾 Winter triticale | 22.34% |
| 🍎 Fruits/veg/flowers | 21.53% |
| 🍑 Orchard | 17.27% |
| 🌾 Mixed cereal | 14.66% |
| 🌿 Sorghum | 13.91% |
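As a quick consistency check, the per-class values above can be averaged to see how they relate to the reported mIoU. This is only a sketch: it assumes mIoU is the unweighted mean over the listed classes, which the numbers appear to support but the repository does not state explicitly.

```python
# Per-class mean IoU values (percent), copied from the table above.
per_class_iou = {
    "Corn": 75.90, "Winter rapeseed": 75.55, "Beet": 73.90,
    "Soft winter wheat": 71.47, "Soybeans": 61.48, "Winter barley": 58.64,
    "Meadow": 56.69, "Sunflower": 54.87, "Background": 50.98,
    "Winter durum wheat": 44.56, "Grapevine": 36.70, "Potatoes": 36.34,
    "Spring barley": 35.08, "Leguminous fodder": 22.41,
    "Winter triticale": 22.34, "Fruits/veg/flowers": 21.53,
    "Orchard": 17.27, "Mixed cereal": 14.66, "Sorghum": 13.91,
}

# Assumption: mIoU is the unweighted average over all listed classes.
miou = sum(per_class_iou.values()) / len(per_class_iou)
print(f"{len(per_class_iou)} classes, mean IoU = {miou:.2f}%")
```

The result agrees with the 44.43% mIoU in the summary table up to rounding.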

🏗️ Model Architecture: Swin-STCLN

1. Swin Spatial Encoder (SEncoder Replacement)

Original STCLN:

DoubleConv + DoubleConv

filters: 32 → 256
kernel: 3×3
Output: (B×T, 256, 32, 32)

Limitations:

  • flat representation
  • single scale
  • limited long-range spatial modeling

Swin-STCLN replaces it with the following stages.

Input: (B, T, 10, 32, 32)

Patch Embedding

Conv2D: 10 → 96 channels, kernel = 2, stride = 2
Spatial: 32×32 → 16×16
Output: (B×T, 96, 16, 16)
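The patch embedding step can be sketched in a few lines of PyTorch. This is a minimal illustration of the stated shapes, not the code in `swin_encoder.py`; the batch size and sequence length are arbitrary placeholders.

```python
import torch
import torch.nn as nn

B, T = 2, 5  # illustrative batch size and number of dates (assumed values)
x = torch.randn(B, T, 10, 32, 32)

# Fold time into the batch axis so every date is embedded independently.
x = x.flatten(0, 1)  # (B*T, 10, 32, 32)

# A 2x2 convolution with stride 2 turns each 2x2 pixel patch into one token.
patch_embed = nn.Conv2d(10, 96, kernel_size=2, stride=2)
tokens = patch_embed(x)  # (B*T, 96, 16, 16)
print(tokens.shape)
```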

Swin Stage 1

2 Swin Transformer Blocks

  • Window attention
  • Shifted window attention

Configuration: dim = 96, window = 4

Output: f_fine, (B×T, 96, 16, 16)
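To make the window mechanics concrete, here is a small sketch of how a 16×16 feature map splits into 4×4 windows, and how a cyclic shift produces the shifted-window variant. The helper name `window_partition` and the shift of half a window follow the usual Swin convention and are assumptions about this repository; the attention mask that Swin applies to shifted windows is omitted.

```python
import torch

def window_partition(x, window):
    """Split (N, H, W, C) into (N * num_windows, window * window, C)."""
    n, h, w, c = x.shape
    x = x.view(n, h // window, window, w // window, window, c)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window * window, c)

feat = torch.randn(3, 16, 16, 96)  # channels-last view of Stage 1 features

# Regular windows: 16 windows per image, 16 tokens per window.
regular = window_partition(feat, 4)

# Shifted windows: roll the map by half a window before partitioning,
# so tokens near window borders can attend across the previous borders.
shifted = window_partition(torch.roll(feat, shifts=(-2, -2), dims=(1, 2)), 4)
print(regular.shape, shifted.shape)
```

Attention is then computed independently inside each window, which keeps the cost linear in the number of windows.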

Patch Merging

Spatial: 16×16 → 8×8
Channel: 96 → 192
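Patch merging in the standard Swin design concatenates each 2×2 group of neighbouring tokens and projects the result to twice the channel width. The sketch below follows that standard formulation; whether this repository uses exactly this layer (or the `timm` version) is an assumption.

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Halve the spatial resolution and double the channels (Swin-style)."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):  # x: (N, C, H, W)
        x = x.permute(0, 2, 3, 1)  # (N, H, W, C)
        # Gather the four tokens of every 2x2 neighbourhood along channels.
        x = torch.cat(
            [x[:, 0::2, 0::2], x[:, 1::2, 0::2],
             x[:, 0::2, 1::2], x[:, 1::2, 1::2]],
            dim=-1,
        )  # (N, H/2, W/2, 4C)
        x = self.reduction(self.norm(x))  # (N, H/2, W/2, 2C)
        return x.permute(0, 3, 1, 2)  # (N, 2C, H/2, W/2)

merged = PatchMerging(96)(torch.randn(2, 96, 16, 16))
print(merged.shape)
```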

Swin Stage 2

2 Swin Transformer Blocks

Configuration: dim = 192, window = 4

Output: f_coarse, (B×T, 192, 8, 8)


2. Temporal Encoder

The original STCLN Transformer Temporal Encoder is kept unchanged.

Configuration:

  • Transformer layers: 3
  • Attention heads: 8
  • Sinusoidal positional encoding
  • GroupNorm

Only the input representation changes.

Original STCLN: (B×32×32) × T × 256

Swin-STCLN: (B×8×8) × T × 192

Benefits:

  • reduces the number of spatial locations processed as temporal sequences from 1024 (32×32) to 64 (8×8)
  • temporal reasoning happens on semantic regions
  • lower memory usage
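The reshaping from per-date feature maps to per-location temporal sequences can be sketched as follows. The stock `nn.TransformerEncoder` is used only as a stand-in with the stated depth and head count: it applies LayerNorm rather than the GroupNorm listed above, and the sinusoidal positional encoding is omitted, so this is not the actual STCLN temporal encoder.

```python
import torch
import torch.nn as nn

B, T, C, H, W = 2, 5, 192, 8, 8  # B and T are illustrative values
f_coarse = torch.randn(B * T, C, H, W)

# Regroup so that each spatial location becomes one sequence over time.
seq = (
    f_coarse.view(B, T, C, H, W)
    .permute(0, 3, 4, 1, 2)      # (B, H, W, T, C)
    .reshape(B * H * W, T, C)    # (B*8*8, T, 192)
)

# Stand-in temporal encoder: 3 layers, 8 heads (192 / 8 = 24 dims per head).
layer = nn.TransformerEncoderLayer(
    d_model=C, nhead=8, dim_feedforward=4 * C, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=3).eval()
with torch.no_grad():
    out = encoder(seq)
print(seq.shape, out.shape)
```

With 32×32 inputs this yields 64 sequences per image instead of the 1024 sequences of the original layout.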

3. STFusion: Cross-Scale Fusion

Replaces the STCLN STA module.

Inputs:

Temporal branch: coarse_agg, (B, 192, 8, 8)

Spatial branch: fine_agg, (B, 96, 16, 16)

Process:

  1. Upsample the temporal branch: 8×8 → 16×16
  2. Projection: 192 → 96 channels
  3. Cross attention
       Query: coarse semantic features
       Key / Value: fine spatial features
  4. Residual fusion
  5. Upsample: 16×16 → 32×32

Output: (B, 128, 32, 32)
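The five steps above can be sketched as a small module. This is an illustration under assumptions, not the code in `stfusion.py`: the class name, the number of heads, bilinear upsampling, and the final 1×1 convolution that maps 96 to 128 channels are all guesses, since the text does not say where the channel count changes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class STFusionSketch(nn.Module):
    def __init__(self, coarse_dim=192, fine_dim=96, out_dim=128, heads=4):
        super().__init__()
        self.proj = nn.Conv2d(coarse_dim, fine_dim, kernel_size=1)
        self.attn = nn.MultiheadAttention(fine_dim, heads, batch_first=True)
        self.out = nn.Conv2d(fine_dim, out_dim, kernel_size=1)  # assumed layer

    def forward(self, coarse_agg, fine_agg):
        # 1-2. Upsample the coarse map to the fine grid, then project 192 -> 96.
        up = F.interpolate(coarse_agg, size=fine_agg.shape[-2:],
                           mode="bilinear", align_corners=False)
        q_map = self.proj(up)                     # (B, 96, 16, 16)
        q = q_map.flatten(2).transpose(1, 2)      # (B, 256, 96)
        kv = fine_agg.flatten(2).transpose(1, 2)  # (B, 256, 96)
        # 3. Cross attention: coarse queries attend to fine keys/values.
        attended, _ = self.attn(q, kv, kv)
        # 4. Residual fusion, back to a feature map.
        fused = (q + attended).transpose(1, 2).reshape_as(q_map)
        # 5. Upsample to the input resolution.
        fused = F.interpolate(fused, scale_factor=2,
                              mode="bilinear", align_corners=False)
        return self.out(fused)                    # (B, 128, 32, 32)

fusion = STFusionSketch().eval()
with torch.no_grad():
    fused = fusion(torch.randn(2, 192, 8, 8), torch.randn(2, 96, 16, 16))
print(fused.shape)
```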


4. Pretraining Reconstruction Decoder

The original STCLN self-supervised learning objective is preserved.

Unchanged:

  • Spatiotemporal masking
  • Mask ratio = 0.4
  • MSE reconstruction loss
  • Unlabeled pretraining patches = 7936

Difference:

Base STCLN reconstructs directly with a Linear layer because its temporal features stay at 32×32 resolution.

Swin-STCLN reconstructs from hierarchical features.

TEncoder output: (B, T, 192, 8, 8)

Reconstruction Head:

ConvTranspose2D: 8×8 → 16×16
ConvTranspose2D: 16×16 → 32×32
Conv2D: 192 → 10 Sentinel-2 bands

Final reconstruction: (B, T, 10, 32, 32)

The same masked pixel MSE objective is applied.
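A reconstruction head with these shapes, together with a masked-pixel MSE, might look like the sketch below. The intermediate channel widths, the GELU activations, the kernel sizes, and the random mask are assumptions made for illustration; the repository's `reconstruction.py` may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

recon_head = nn.Sequential(
    nn.ConvTranspose2d(192, 192, kernel_size=2, stride=2),  # 8x8 -> 16x16
    nn.GELU(),
    nn.ConvTranspose2d(192, 192, kernel_size=2, stride=2),  # 16x16 -> 32x32
    nn.GELU(),
    nn.Conv2d(192, 10, kernel_size=1),                      # 192 -> 10 bands
)

B, T = 2, 5                                   # illustrative sizes
target = torch.randn(B, T, 10, 32, 32)        # original (unmasked) series
mask = torch.rand(B, T, 1, 32, 32) < 0.4      # True where a pixel was masked
tenc_out = torch.randn(B, T, 192, 8, 8)       # placeholder TEncoder output

recon = recon_head(tenc_out.flatten(0, 1)).view(B, T, 10, 32, 32)

# MSE restricted to masked pixels (the mask is broadcast over the bands).
loss = ((recon - target) ** 2)[mask.expand_as(target)].mean()
print(recon.shape, loss.item())
```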


5. Finetuning Decoder

The original STCLN linear decoder is replaced with a boundary-aware segmentation decoder.

Input:

STFusion output

(B,128,32,32)

Semantic Decoder

Architecture:

Conv-BN-ReLU: 128 → 64
Conv-BN-ReLU: 64 → 64
1×1 Conv classifier

Output: (B, 18, 32, 32)

Boundary Decoder

Parallel boundary prediction branch:

Conv-BN-ReLU: 128 → 64
Conv-BN-ReLU: 64 → 64
1×1 Conv

Output: (B, 1, 32, 32)

Boundary supervision is generated automatically by applying a morphological gradient to the semantic masks.

No additional annotation is required.
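One way to derive boundary targets from label maps is to take dilation minus erosion, which is non-zero exactly where a pixel's neighbourhood contains more than one class. The sketch below implements that with max pooling; the function name and the 3×3 kernel are assumptions, not necessarily what `boundary_loss.py` uses.

```python
import torch
import torch.nn.functional as F

def boundary_from_labels(labels, kernel=3):
    """labels: (B, H, W) integer class map -> (B, 1, H, W) float boundary mask."""
    x = labels.unsqueeze(1).float()
    pad = kernel // 2
    # Replicate padding avoids spurious boundaries along the image border.
    x = F.pad(x, (pad, pad, pad, pad), mode="replicate")
    dilated = F.max_pool2d(x, kernel, stride=1)    # local maximum label
    eroded = -F.max_pool2d(-x, kernel, stride=1)   # local minimum label
    # Morphological gradient > 0 means the window holds different labels.
    return (dilated != eroded).float()

# Toy example: left half is class 0, right half is class 1.
labels = torch.zeros(1, 6, 6, dtype=torch.long)
labels[:, :, 3:] = 1
boundary = boundary_from_labels(labels)
print(boundary[0, 0])
```

In this toy case only the two columns on either side of the class change are marked as boundary.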


Gated Refinement

Semantic feature

Boundary feature

↓

Concatenation

↓

Convolution refinement

↓

Final refined prediction:

(B,18,32,32)
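The two branches and the refinement step can be sketched together. Everything beyond the channel widths stated above is an assumption: the 3×3 kernels, the sigmoid gate computed from the concatenated features, and the class name are illustrative choices rather than the contents of `decoder.py`.

```python
import torch
import torch.nn as nn

def conv_bn_relu(c_in, c_out):
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )

class BoundaryAwareDecoderSketch(nn.Module):
    def __init__(self, in_dim=128, mid=64, num_classes=18):
        super().__init__()
        self.sem = nn.Sequential(conv_bn_relu(in_dim, mid), conv_bn_relu(mid, mid))
        self.bnd = nn.Sequential(conv_bn_relu(in_dim, mid), conv_bn_relu(mid, mid))
        self.sem_cls = nn.Conv2d(mid, num_classes, kernel_size=1)
        self.bnd_cls = nn.Conv2d(mid, 1, kernel_size=1)
        # Assumed gate: a sigmoid map computed from the concatenated features.
        self.gate = nn.Sequential(nn.Conv2d(2 * mid, mid, kernel_size=1), nn.Sigmoid())
        self.refine = nn.Sequential(
            conv_bn_relu(mid, mid), nn.Conv2d(mid, num_classes, kernel_size=1)
        )

    def forward(self, x):
        s, b = self.sem(x), self.bnd(x)
        g = self.gate(torch.cat([s, b], dim=1))  # concatenation -> gate
        refined = self.refine(s + g * b)         # gated boundary injection
        return self.sem_cls(s), self.bnd_cls(b), refined

decoder = BoundaryAwareDecoderSketch().eval()
with torch.no_grad():
    sem_logits, bnd_logits, refined_logits = decoder(torch.randn(2, 128, 32, 32))
print(sem_logits.shape, bnd_logits.shape, refined_logits.shape)
```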

6. Finetuning Loss

Original STCLN:

Cross Entropy

Swin-STCLN:

Total Loss = CE(semantic output) + CE(refined output) + 0.5 × BCE(boundary output)

The boundary loss improves separation between neighbouring crop parcels.

Boundary weight = 0.5
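The combined objective translates directly into PyTorch. The function name and argument names below are hypothetical; only the three terms and the 0.5 weight come from the description above. Using the logits form of BCE is an assumption.

```python
import math

import torch
import torch.nn.functional as F

def swin_stcln_loss(sem_logits, refined_logits, boundary_logits,
                    labels, boundary_target, boundary_weight=0.5):
    """CE(semantic) + CE(refined) + boundary_weight * BCE(boundary)."""
    ce_sem = F.cross_entropy(sem_logits, labels)
    ce_ref = F.cross_entropy(refined_logits, labels)
    bce = F.binary_cross_entropy_with_logits(boundary_logits, boundary_target)
    return ce_sem + ce_ref + boundary_weight * bce

# Sanity check with all-zero logits: each CE equals ln(18), the BCE equals ln(2).
labels = torch.randint(0, 18, (2, 32, 32))
boundary_target = torch.randint(0, 2, (2, 1, 32, 32)).float()
loss = swin_stcln_loss(
    torch.zeros(2, 18, 32, 32), torch.zeros(2, 18, 32, 32),
    torch.zeros(2, 1, 32, 32), labels, boundary_target,
)
expected = 2 * math.log(18) + 0.5 * math.log(2)
print(loss.item(), expected)
```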


⚑ Inference Speed

Measured on AMD MI300X GPU.

| Batch Size | Time (ms) | Throughput (patches/sec) | VRAM Used (GB) |
|---|---|---|---|
| 1 | 7.8 | 128.9 | 1.12 |
| 4 | 19.6 | 203.8 | 2.08 |
| 8 | 37.1 | 215.4 | 3.40 |
| 16 | 73.4 | 218.0 | 6.03 |
| 32 | 143.0 | 223.8 | 11.31 |
| 64 | 281.0 | 227.8 | 21.85 |
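Throughput here is batch size divided by batch latency. A short sketch of that arithmetic, using the rounded latencies from the table (so small differences from the reported throughput are expected):

```python
def throughput(batch_size, time_ms):
    """Patches per second for one forward pass over a batch."""
    return batch_size / (time_ms / 1000.0)

# (batch size, latency in ms) pairs copied from the table above.
measurements = [(1, 7.8), (4, 19.6), (8, 37.1), (16, 73.4), (32, 143.0), (64, 281.0)]

for batch_size, time_ms in measurements:
    print(f"batch {batch_size:>2}: {throughput(batch_size, time_ms):6.1f} patches/sec")
```

Throughput grows only slightly beyond batch size 8, while VRAM use keeps growing roughly linearly, so small to medium batches give most of the available speed.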

📁 Repository Structure

Swin-STCLN-PASTIS/
├── models/
│   ├── swin_encoder.py        # PatchEmbed + Swin Blocks + PatchMerging
│   ├── temporal_encoder.py    # STCLN Transformer Encoder
│   ├── stfusion.py            # Cross-scale attention fusion
│   ├── decoder.py             # Semantic decoder, boundary decoder, gated refinement
│   ├── reconstruction.py      # ConvTranspose reconstruction head
│   └── swin_stcln.py          # Complete architecture
├── datasets/
│   └── pastis_dataset.py
├── losses/
│   ├── segmentation_loss.py
│   └── boundary_loss.py
├── evaluation/
│   └── metrics.py
├── train.py
├── pretrain.py
├── finetune.py
├── visualize_results.py
├── checkpoints/
└── results/

🚀 Quick Start

Installation

git clone https://huggingface.co/Dhruv1000/Swin-STCLN-PASTIS

cd Swin-STCLN-PASTIS


pip install torch torchvision timm einops geopandas matplotlib scikit-learn

Training

Single fold:

python train.py \
    --data_root /path/to/PASTIS \
    --fold 1 \
    --epochs 100 \
    --batch_size 16 \
    --lr 5e-5 \
    --weight_decay 0.05 \
    --warmup_iters 500 \
    --num_workers 4 \
    --amp \
    --work_dir ./work_dirs/fold1

5-Fold Cross Validation

for fold in 1 2 3 4 5
do

python train.py \
    --data_root /path/to/PASTIS \
    --fold $fold \
    --epochs 100 \
    --batch_size 16 \
    --lr 5e-5 \
    --work_dir ./work_dirs/fold${fold}

done

Inference

import torch

from models.swin_stcln import build_swin_stcln

model = build_swin_stcln(num_classes=18)

# map_location="cpu" lets the checkpoint load on machines without a GPU.
checkpoint = torch.load(
    "checkpoints/best_model.pth",
    map_location="cpu",
    weights_only=False,
)
model.load_state_dict(checkpoint["model"])
model.eval()

# Input layout: (Batch, Time, Sentinel-2 Bands, Height, Width)
x = torch.randn(1, 32, 10, 32, 32)

# Disable gradient tracking during inference.
with torch.no_grad():
    logits = model(x)  # (1, 18, 32, 32)

prediction = logits.argmax(dim=1)  # (1, 32, 32)

📋 Training Configuration

| Parameter | Value |
|---|---|
| Model | Swin-STCLN |
| Spatial Encoder | Swin Transformer |
| Temporal Encoder | STCLN Transformer Encoder |
| Fusion | Cross-scale STFusion |
| Optimizer | AdamW, β = (0.9, 0.999) |
| Learning rate | 5e-5 |
| Weight decay | 0.05 |
| Schedule | Warmup 500 iters + cosine decay |
| Batch size | 16 |
| Epochs | 100 |
| AMP | Enabled |
| Gradient clipping | max_norm = 5.0 |
| Loss | CE + CE + 0.5 × BCE |
| Input bands | Sentinel-2, 10 bands |
| Input size | 32×32 |
| Classes | 18 |
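The schedule row can be read as a learning rate that ramps up over the first 500 iterations and then follows a cosine curve. The sketch below assumes a linear warmup and decay to zero, and takes the total number of iterations as a parameter, since none of these details is given in the table.

```python
import math

def lr_at(step, total_steps, base_lr=5e-5, warmup_iters=500, min_lr=0.0):
    """Learning rate at a given iteration: linear warmup, then cosine decay."""
    if step < warmup_iters:
        # Assumed linear ramp from base_lr / warmup_iters up to base_lr.
        return base_lr * (step + 1) / warmup_iters
    progress = (step - warmup_iters) / max(1, total_steps - warmup_iters)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

total = 10_000  # hypothetical number of training iterations
for step in (0, 250, 499, 500, 5_250, 10_000):
    print(f"iter {step:>6}: lr = {lr_at(step, total):.2e}")
```

The same function can be wrapped in `torch.optim.lr_scheduler.LambdaLR` by returning the ratio `lr_at(step, total) / base_lr`.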

📦 Dataset: PASTIS

The model is evaluated on the PASTIS (Panoptic Agricultural Satellite Time Series) benchmark.

| Property | Details |
|---|---|
| Total patches | 2,433 geo-referenced tiles |
| Satellite | Sentinel-2 |
| Spectral bands | 10 |
| Temporal observations | 61 |
| Input crop | 32×32 pixels |
| Classes | 18 crop classes |
| Splits | Official 5-fold geographic split |
| Pretraining data | 7936 unlabeled patches |
| Label fraction | 2% labelled setting |

The official train / validation / test split is preserved.


🌾 PASTIS Classes

| ID | Class | Avg IoU |
|---|---|---|
| 0 | Background | 50.98% |
| 1 | Meadow | 56.69% |
| 2 | Soft winter wheat | 71.47% |
| 3 | Corn | 75.90% |
| 4 | Winter barley | 58.64% |
| 5 | Winter rapeseed | 75.55% |
| 6 | Spring barley | 35.08% |
| 7 | Sunflower | 54.87% |
| 8 | Grapevine | 36.70% |
| 9 | Beet | 73.90% |
| 10 | Winter triticale | 22.34% |
| 11 | Winter durum wheat | 44.56% |
| 12 | Fruits/veg/flowers | 21.53% |
| 13 | Potatoes | 36.34% |
| 14 | Leguminous fodder | 22.41% |
| 15 | Soybeans | 61.48% |
| 16 | Orchard | 17.27% |
| 17 | Mixed cereal | 14.66% |
| 18 | Sorghum | 13.91% |

🔥 What Remains Identical to the Original STCLN

The following components are unchanged:

✅ Pretrain → finetune workflow

✅ Spatiotemporal masked reconstruction

✅ Mask ratio: 0.4

✅ Reconstruction objective: Mean Squared Error

✅ Temporal Encoder design:

  • 3 Transformer layers
  • 8 attention heads
  • sinusoidal positional encoding
  • GroupNorm

✅ Dataset protocol:

  • PASTIS32
  • Sentinel-2
  • 10 spectral channels
  • Official folds

✅ Weight transfer: pretrained SEncoder + TEncoder → finetuning initialization

🆚 STCLN vs Swin-STCLN

| Component | STCLN | Swin-STCLN |
|---|---|---|
| Spatial Encoder | DoubleConv CNN | Swin Transformer |
| Spatial hierarchy | Single scale | Multi scale |
| Feature output | 256 @ 32×32 | 96 @ 16×16 + 192 @ 8×8 |
| Spatial attention | Local CNN | Window self-attention |
| Temporal input | 1024 spatial tokens | 64 semantic tokens |
| TEncoder | Transformer | Same Transformer |
| Fusion | STA | Cross-scale STFusion |
| Reconstruction | Linear | ConvTranspose decoder |
| Decoder | Linear classifier | Semantic + Boundary Decoder |
| Boundary learning | No | Yes |
| Final refinement | No | Gated refinement |
| Loss | CE | CE + CE + Boundary BCE |

📈 Training Dynamics (Fold 1)

| Epoch | Train Loss | Val Loss | mFscore | mIoU | Kappa |
|---|---|---|---|---|---|
| 1 | 0.878 | 0.671 | 2.79% | 1.47% | 2.39% |
| 4 | 0.431 | 0.472 | 20.08% | 12.34% | 10.76% |
| 10 | 0.320 | 0.380 | ~33% | ~22% | ~24% |
| 18 | 0.222 | 0.323 | 35.55% | 24.17% | 24.93% |
| 55 | 0.083 | 0.363 | 53.37% | 39.92% | 53.16% |
| 92 | 0.050 | 0.350 | 58.20% | 44.20% | 60.0% |
| 100 | 0.048 | 0.360 | 57.90% | 44.10% | 59.8% |

Best checkpoint:

Epoch 92

Total training time per fold:

~32 minutes on AMD MI300X

🖼️ Visualizations

Generated evaluation plots:

Per Fold

results/fold{N}/plots/

Includes:

  • Training curves
  • Per-class IoU plots
  • Metrics radar chart
  • Confusion matrix
  • Prediction maps
  • Error maps
  • IoU scatter analysis
  • Overfitting analysis

🔧 Implementation Details

Pure PyTorch Implementation

No dependency on:

  • MMSegmentation
  • MMEngine
  • external training framework

Implemented with:

  • PyTorch
  • timm
  • einops

Major Architectural Changes

1. Hierarchical Spatial Learning

CNN encoder replaced by Swin Transformer blocks.

Benefits:

  • larger receptive field
  • window attention
  • shifted window information exchange

2. Efficient Temporal Modelling

Instead of:

1024 temporal sequences/image

Swin-STCLN uses:

64 temporal sequences/image

This reduces memory while keeping semantic information.


3. Cross Scale Feature Recovery

The coarse temporal representation loses boundary detail.

STFusion uses cross-attention fusion to recombine:

  • high-level semantics from the temporal branch
  • fine spatial details from the spatial branch


4. Boundary Aware Learning

Boundary decoder learns crop separation using automatically generated masks.

No manual boundary annotation needed.


📚 Citation

If this implementation is useful to you, please cite the original STCLN work and the PASTIS benchmark.

@inproceedings{garnot2021pastis,
  title={Panoptic Segmentation of Satellite Image Time Series with Convolutional Temporal Attention Networks},
  author={Garnot, Vivien Sainte Fare and Landrieu, Loic},
  booktitle={ICCV},
  year={2021}
}

For Swin Transformer:

@inproceedings{liu2021swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and others},
  booktitle={ICCV},
  year={2021}
}

📄 License

Apache-2.0


Trained on AMD MI300X
ROCm 7.0
PyTorch 2.x

Swin-STCLN × PASTIS

Hierarchical Swin Spatial Encoder
+ STCLN Temporal Encoder
+ Cross Scale Boundary-Aware Fusion
