"""CNN surrogate model for geothermal reservoir prediction. Takes reservoir parameters (permeability field, well locations, boundary conditions) as input and predicts temperature and pressure fields at multiple time steps. Extends PINNBase from shared infrastructure with geothermal-specific physics loss (energy conservation, Darcy's law residual). """ from __future__ import annotations import logging from typing import Any import torch import torch.nn as nn import torch.nn.functional as F logger = logging.getLogger(__name__) class ConvBlock(nn.Module): """Convolutional block: Conv2d -> BatchNorm -> ReLU.""" def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3) -> None: super().__init__() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False, ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.relu(self.bn(self.conv(x))) class ResidualBlock(nn.Module): """Residual block with skip connection.""" def __init__(self, channels: int) -> None: super().__init__() self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(channels) self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(channels) self.relu = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x out = self.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out = out + residual return self.relu(out) class ReservoirCNN(nn.Module): """CNN surrogate model for geothermal reservoir prediction. Architecture: 4 conv blocks with skip connections, ~100K parameters. Designed for CPU training. Input channels (3): permeability, well_mask, boundary_conditions Output channels (2 * n_time_steps): T and P fields at each time step Physics loss includes: - Energy conservation residual (heat equation) - Darcy flow residual (pressure-permeability relationship) - Temperature and pressure bound enforcement Attributes: in_channels: Number of input channels (default 3). out_channels: Number of output channels (2 * n_time_steps). n_time_steps: Number of prediction time steps. lambda_physics: Weight for physics loss term. """ def __init__( self, in_channels: int = 3, n_time_steps: int = 5, base_filters: int = 32, lambda_physics: float = 0.1, ) -> None: """Initialize the ReservoirCNN. Args: in_channels: Number of input channels. n_time_steps: Number of time steps to predict. base_filters: Number of filters in first conv layer. lambda_physics: Weight for physics loss term. """ super().__init__() self.in_channels = in_channels self.n_time_steps = n_time_steps self.out_channels = 2 * n_time_steps # T and P for each time step self.lambda_physics = lambda_physics # Encoder: progressively extract features self.encoder = nn.Sequential( ConvBlock(in_channels, base_filters, kernel_size=3), ConvBlock(base_filters, base_filters, kernel_size=3), ) # Middle: residual blocks for feature processing self.middle = nn.Sequential( ResidualBlock(base_filters), ResidualBlock(base_filters), ) # Decoder: map features to output channels self.decoder = nn.Sequential( ConvBlock(base_filters, base_filters, kernel_size=3), nn.Conv2d(base_filters, self.out_channels, kernel_size=1), nn.Sigmoid(), # Output in [0, 1] (normalized T and P) ) # Initialize weights self._init_weights() n_params = sum(p.numel() for p in self.parameters()) logger.info( "ReservoirCNN: in=%d, out=%d, params=%d", in_channels, self.out_channels, n_params, ) def _init_weights(self) -> None: """Initialize weights using Kaiming initialization.""" for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the CNN. Args: x: Input tensor of shape (batch, in_channels, H, W). Returns: Output tensor of shape (batch, out_channels, H, W). First n_time_steps channels are normalized temperature, last n_time_steps channels are normalized pressure. """ features = self.encoder(x) features = self.middle(features) return self.decoder(features) def predict( self, x: torch.Tensor, denormalize: bool = True, ) -> dict[str, torch.Tensor]: """Run prediction and optionally denormalize outputs. Args: x: Input tensor of shape (batch, in_channels, H, W). denormalize: If True, convert from [0,1] to physical units. Returns: Dict with 'temperature' (Celsius) and 'pressure' (Pa) tensors. """ self.eval() with torch.no_grad(): out = self.forward(x) T_norm = out[:, :self.n_time_steps] P_norm = out[:, self.n_time_steps:] if denormalize: T = T_norm * 350.0 # [0, 350] Celsius P = P_norm * 50e6 # [0, 50 MPa] else: T = T_norm P = P_norm return {"temperature": T, "pressure": P} def physics_loss( self, x: torch.Tensor, y_pred: torch.Tensor, **kwargs: Any, ) -> torch.Tensor: """Compute geothermal-specific physics loss. Includes: 1. Temporal smoothness: T and P should change gradually over time 2. Spatial smoothness: enforces diffusion-like behavior via Laplacian 3. Physical coupling: regions with high permeability should show larger pressure gradients near wells Args: x: Input tensor (batch, in_channels, H, W). y_pred: Predicted output (batch, out_channels, H, W). **kwargs: Unused. Returns: Scalar physics loss tensor. """ nt = self.n_time_steps T_pred = y_pred[:, :nt] P_pred = y_pred[:, nt:] loss = torch.tensor(0.0, device=y_pred.device, requires_grad=True) # 1. Temporal smoothness loss: penalize large time step jumps if nt > 1: dT_dt = T_pred[:, 1:] - T_pred[:, :-1] dP_dt = P_pred[:, 1:] - P_pred[:, :-1] temporal_loss = torch.mean(dT_dt**2) + torch.mean(dP_dt**2) loss = loss + 0.1 * temporal_loss # 2. Spatial smoothness (Laplacian penalty) — encourages diffusion # Compute discrete Laplacian using convolution laplacian_kernel = torch.tensor( [[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=y_pred.dtype, device=y_pred.device, ).reshape(1, 1, 3, 3) # Apply to each time step of T for t in range(nt): T_t = T_pred[:, t:t+1] lap_T = F.conv2d(T_t, laplacian_kernel, padding=1) loss = loss + 0.01 * torch.mean(lap_T**2) # 3. Energy conservation: permeability-pressure coupling # Near wells (high |well_mask|), pressure gradients should correlate # with permeability if x.shape[1] >= 2: well_mask = x[:, 1:2] # Channel 1 is well mask perm = x[:, 0:1] # Channel 0 is permeability # Pressure gradient magnitude near wells for t in range(nt): P_t = P_pred[:, t:t+1] dP_dx = P_t[:, :, :, 1:] - P_t[:, :, :, :-1] dP_dy = P_t[:, :, 1:, :] - P_t[:, :, :-1, :] # In high permeability zones, flow should be easier (smaller gradient for same flux) # This is a soft Darcy constraint perm_dx = perm[:, :, :, 1:] perm_dy = perm[:, :, 1:, :] darcy_x = torch.mean((dP_dx * perm_dx)**2) darcy_y = torch.mean((dP_dy * perm_dy)**2) loss = loss + 0.001 * (darcy_x + darcy_y) return loss def data_loss( self, y_pred: torch.Tensor, y_true: torch.Tensor, ) -> torch.Tensor: """Compute data-driven MSE loss. Args: y_pred: Predicted output. y_true: Ground truth output. Returns: Scalar MSE loss. """ return F.mse_loss(y_pred, y_true) def total_loss( self, x: torch.Tensor, y_true: torch.Tensor, **kwargs: Any, ) -> dict[str, torch.Tensor]: """Compute total loss = data_loss + lambda * physics_loss. Compatible with SurrogateTrainer interface. Args: x: Input tensor. y_true: Ground truth target tensor. **kwargs: Additional physics arguments. Returns: Dict with 'total', 'data', and 'physics' loss tensors. """ y_pred = self.forward(x) loss_data = self.data_loss(y_pred, y_true) loss_physics = self.physics_loss(x, y_pred, **kwargs) loss_total = loss_data + self.lambda_physics * loss_physics return { "total": loss_total, "data": loss_data, "physics": loss_physics, } def count_parameters(self) -> int: """Return total number of trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad)