""" Swin-STCLN Training Script for PASTIS Architecture: Swin SEncoder + STCLN Transformer TEncoder + Cross Scale STFusion + Semantic Decoder + Boundary Decoder + Gated Refinement """ import os import sys import time import json import math import argparse from pathlib import Path from datetime import datetime,timedelta import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.amp import ( GradScaler, autocast ) sys.path.insert( 0, str(Path(__file__).parent) ) # =============================== # PROJECT IMPORTS # =============================== from models.swin_stcln import build_swin_stcln from datasets.pastis_dataset import ( PASTISDataset, IGNORE_INDEX, PASTIS_CLASSES ) from losses.swin_stcln_loss import ( SwinSTCLNLoss ) from evaluation.metrics import ( SegmentationMetrics ) # =============================== # COLORS # =============================== GREEN="\033[92m" CYAN="\033[96m" RESET="\033[0m" def fmt_time(x): return str( timedelta(seconds=int(x)) ) # =============================== # LR Scheduler # =============================== class WarmupCosineScheduler: def __init__( self, optimizer, warmup_iters, total_iters, base_lr, min_lr=1e-6 ): self.optimizer=optimizer self.warmup=warmup_iters self.total=total_iters self.lr=base_lr self.min=min_lr self.i=0 def step(self): self.i+=1 if self.i < self.warmup: lr=self.min+(self.lr-self.min)*( self.i/self.warmup ) else: p=( self.i-self.warmup )/(self.total-self.warmup) lr=self.min+0.5*( self.lr-self.min )*(1+math.cos(math.pi*p)) for g in self.optimizer.param_groups: g["lr"]=lr return lr # =============================== # TRAIN # =============================== def train_epoch( model, loader, optimizer, scheduler, criterion, scaler, device, args ): model.train() total=0 for i,batch in enumerate(loader): s2=batch["S2"].to(device) label=batch["label"].to(device) optimizer.zero_grad( set_to_none=True ) with autocast( "cuda", enabled=args.amp ): outputs=model(s2) loss_dict=criterion( outputs, label ) loss=loss_dict["loss"] scaler.scale( loss ).backward() scaler.unscale_( optimizer ) nn.utils.clip_grad_norm_( model.parameters(), 5.0 ) scaler.step( optimizer ) scaler.update() scheduler.step() total+=loss.item() if i%50==0: print( f"iter {i}/{len(loader)} loss={loss.item():.4f}" ) return total/len(loader) # =============================== # VALIDATION # =============================== @torch.no_grad() def validate( model, loader, criterion, device, args ): model.eval() metrics=SegmentationMetrics( args.num_classes, IGNORE_INDEX ) loss_sum=0 for batch in loader: s2=batch["S2"].to(device) label=batch["label"].to(device) with autocast( "cuda", enabled=args.amp ): outputs=model(s2) loss=criterion( outputs, label )["loss"] logits=outputs["refined"] metrics.update( logits.float(), label ) loss_sum+=loss.item() result=metrics.compute() result["val_loss"]=loss_sum/len(loader) return result # =============================== # ARGS # =============================== def parse_args(): p=argparse.ArgumentParser( "Swin-STCLN PASTIS" ) p.add_argument( "--data_root", default="/workspace/project/PASTIS" ) p.add_argument( "--fold", type=int, default=1 ) p.add_argument( "--epochs", type=int, default=100 ) p.add_argument( "--batch_size", type=int, default=16 ) p.add_argument( "--lr", type=float, default=5e-5 ) p.add_argument( "--weight_decay", type=float, default=0.05 ) p.add_argument( "--warmup_iters", type=int, default=500 ) p.add_argument( "--num_workers", type=int, default=4 ) p.add_argument( "--num_frames", type=int, default=32 ) p.add_argument( "--num_classes", type=int, default=18 ) p.add_argument( "--amp", action="store_true", default=True ) p.add_argument( "--work_dir", default="./work_dirs/swin_stcln" ) return p.parse_args() # =============================== # MAIN # =============================== def main(): args=parse_args() os.makedirs( args.work_dir, exist_ok=True ) device=torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) print( CYAN+ "Swin-STCLN × PASTIS Training" +RESET ) print( "Device:", device ) # DATASETS train_ds=PASTISDataset( args.data_root, fold=args.fold, split="train", num_frames=args.num_frames, augment=True ) val_ds=PASTISDataset( args.data_root, fold=args.fold, split="val", num_frames=args.num_frames, augment=False ) train_loader=DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True ) val_loader=DataLoader( val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True ) # MODEL model=build_swin_stcln( num_classes=args.num_classes, in_channels=10, embed_dim=96, temporal_dim=192 ) model=model.to(device) params=sum( p.numel() for p in model.parameters() )/1e6 print( f"Parameters: {params:.2f}M" ) # LOSS criterion=SwinSTCLNLoss( ignore_index=IGNORE_INDEX, boundary_weight=0.5 ) optimizer=torch.optim.AdamW( model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) scheduler=WarmupCosineScheduler( optimizer, args.warmup_iters, args.epochs*len(train_loader), args.lr ) scaler=GradScaler( "cuda", enabled=args.amp ) best=0 for epoch in range(args.epochs): print( f"\nEpoch {epoch+1}/{args.epochs}" ) loss=train_epoch( model, train_loader, optimizer, scheduler, criterion, scaler, device, args ) val=validate( model, val_loader, criterion, device, args ) print( val ) score=val["mFscore"] if score>best: best=score torch.save( { "model":model.state_dict(), "epoch":epoch, "best_mFscore":best }, os.path.join( args.work_dir, "best_model.pth" ) ) print( GREEN+ f"Saved best {best:.2f}" +RESET ) print( "Training Finished" ) if __name__=="__main__": main()