# Copyright (c) 2026 Islam I. Abdulaal (Alexandria University, es-eslam.ibrahim2026@alexu.edu.eg)
# and Omar A. M. Abdelraouf (IMRE, A*STAR, omar_abdelrahman@a-star.edu.sg). All rights reserved.
#
# Licensed under the MIT License. Model checkpoints are released under CC BY 4.0.
"""NanoPhotoNet-MPM inference from safetensors checkpoints.

Pipeline (see ``main``):

1. ``EigenmodeDeepONet`` (transverse solver) predicts six-channel field profiles on a
   normalized cross-section grid, plus a scalar effective index, for a pump and a
   signal geometry vector.
2. ``physics_converter`` reduces those to an overlap proxy ``K`` and a phase mismatch
   ``delta_beta``.
3. ``CWEPINN`` (longitudinal propagator) maps ``K``, ``delta_beta``, losses, initial
   amplitudes and a propagation distance to three complex envelopes.
"""

from __future__ import annotations

import argparse
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn


class RFFEmbedding(nn.Module):
    """Random Fourier feature embedding.

    Projects ``r`` (last dim ``in_features``) through a fixed Gaussian matrix ``B`` of
    shape ``(out_features // 2, in_features)`` scaled by ``sigma`` and returns
    ``cat([sin(r @ B.T), cos(r @ B.T)], -1) / sqrt(out_features // 2)``. ``B`` is a
    persistent buffer, so loading a checkpoint with ``strict=True`` overwrites the
    random initialization.
    """

    def __init__(self, in_features: int = 2, out_features: int = 512, sigma: float = 10.0):
        super().__init__()
        self.proj_dim = out_features // 2
        self.register_buffer("B", torch.randn(self.proj_dim, in_features) * sigma)

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        projected = torch.matmul(r, self.B.t())
        # Each (sin, cos) pair has unit norm, so dividing by sqrt(proj_dim) gives the
        # feature vector unit L2 norm.
        return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1) / math.sqrt(
            self.proj_dim
        )


class MLPBlock(nn.Module):
    """``Linear -> LayerNorm -> GELU`` building block."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.ln = nn.LayerNorm(out_dim)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.ln(self.linear(x)))


class BranchNet(nn.Module):
    """Branch network: encodes a parameter vector ``p`` (last dim ``in_features``).

    Returns ``(coefficients, n_eff)``. ``coefficients`` has ``6 * synthesis_dim``
    values per sample, which ``EigenmodeDeepONet`` reshapes to
    ``(batch, 6, synthesis_dim)``. ``n_eff`` is a single scalar per sample from a
    separate small head, which the pipeline uses as the effective index.
    """

    def __init__(self, in_features: int = 5, hidden_dim: int = 256, synthesis_dim: int = 128):
        super().__init__()
        self.mlp = nn.Sequential(
            MLPBlock(in_features, hidden_dim),
            MLPBlock(hidden_dim, hidden_dim),
            MLPBlock(hidden_dim, hidden_dim),
        )
        self.proj_fields = nn.Linear(hidden_dim, 6 * synthesis_dim)
        self.index_head = nn.Sequential(
            nn.Linear(hidden_dim, 64), nn.LayerNorm(64), nn.GELU(), nn.Linear(64, 1)
        )

    def forward(self, p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.mlp(p)
        return self.proj_fields(features), self.index_head(features)


class TrunkNet(nn.Module):
    """Trunk network: maps cross-section coordinates to ``synthesis_dim`` basis values."""

    def __init__(
        self,
        in_features: int = 2,
        embedding_dim: int = 512,
        hidden_dim: int = 512,
        synthesis_dim: int = 128,
        rff_sigma: float = 10.0,
    ):
        super().__init__()
        self.rff = RFFEmbedding(in_features, embedding_dim, sigma=rff_sigma)
        self.mlp = nn.Sequential(
            MLPBlock(embedding_dim, hidden_dim),
            MLPBlock(hidden_dim, hidden_dim),
            MLPBlock(hidden_dim, synthesis_dim),
        )

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        # Divide x by 30 and y by 20. This maps the +/-30 x +/-20 grid built in main()
        # onto [-1, 1] before the Fourier embedding.
        scale = torch.tensor([30.0, 20.0], device=r.device, dtype=r.dtype)
        return self.mlp(self.rff(r / scale))


class EigenmodeDeepONet(nn.Module):
    """DeepONet transverse solver combining ``BranchNet`` and ``TrunkNet``.

    ``forward(p, r)`` supports two layouts:

    * Grid mode: ``p`` is ``(B, branch_in)`` and ``r`` is ``(1, N, trunk_in)`` or
      ``(B, N, trunk_in)``. Returns fields of shape ``(B, 6, N)``. A coordinate batch
      of 1 (or B identical grids) is evaluated once and shared by every parameter
      vector. Distinct per-sample grids are evaluated individually.
    * Pointwise mode (any other ranks): branch and trunk features are multiplied
      elementwise, with broadcasting, and summed over ``synthesis_dim``. For example,
      ``p`` ``(N, branch_in)`` with ``r`` ``(N, trunk_in)`` gives ``(N, 6)``.

    Fields pass through ``tanh``, so every entry lies in ``[-1, 1]``. The second return
    value is the branch network's scalar head (shape ``(B, 1)`` for 2-D ``p``).

    Raises:
        ValueError: in grid mode, if ``r``'s batch is neither 1 nor ``p``'s batch.
    """

    def __init__(
        self,
        branch_in: int = 5,
        trunk_in: int = 2,
        hidden_dim: int = 256,
        synthesis_dim: int = 128,
        rff_sigma: float = 10.0,
    ):
        super().__init__()
        self.synthesis_dim = synthesis_dim
        self.branch = BranchNet(branch_in, hidden_dim, synthesis_dim)
        self.trunk = TrunkNet(trunk_in, 512, 512, synthesis_dim, rff_sigma=rff_sigma)

    def forward(self, p: torch.Tensor, r: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        branch_feats, n_eff = self.branch(p)

        if r.dim() == 3 and p.dim() == 2:
            # The batch size comes from the parameters, so a single shared grid
            # (r batch 1) works for any number of geometries.
            batch_size = p.shape[0]
            if r.shape[0] not in (1, batch_size):
                raise ValueError(
                    f"coordinate batch ({r.shape[0]}) must be 1 or match the "
                    f"parameter batch ({batch_size})"
                )
            if r.shape[0] == 1 or torch.equal(r, r[:1].expand_as(r)):
                # Shared grid: run the trunk once and broadcast the result.
                trunk_feats = self.trunk(r[0]).unsqueeze(0).expand(batch_size, -1, -1)
            else:
                # Distinct grids per sample: do not silently reuse the first one.
                trunk_feats = self.trunk(r)
            branch_fields = branch_feats.view(batch_size, 6, self.synthesis_dim)
            field_profile = torch.tanh(torch.bmm(branch_fields, trunk_feats.transpose(1, 2)))
            return field_profile, n_eff

        trunk_feats = self.trunk(r)
        branch_fields = branch_feats.view(-1, 6, self.synthesis_dim)
        field_profile = torch.tanh(torch.sum(branch_fields * trunk_feats.unsqueeze(1), dim=-1))
        return field_profile, n_eff


class CWEPINN(nn.Module):
    """Longitudinal propagator: predicts three complex envelopes at distance ``z``.

    ``state`` is ``(batch, 9)``, laid out as ``main`` assembles it::

        [K, delta_beta, alpha_0, alpha_1, alpha_2, a_0, a_1, a_2, z]

    The entries are:

    * ``K``: overlap proxy.
    * ``delta_beta``: phase mismatch (rad/m).
    * ``alpha_0..alpha_2``: three loss coefficients (1/m, from
      ``alpha_db_cm_to_np_per_m``).
    * ``a_0..a_2``: three initial amplitudes.
    * ``z``: distance (m).

    The output is ``baseline + tanh(z / length_scale) * output_scale * mlp(scaled)``,
    where the baseline is ``[a_0 exp(-alpha_0 z/2), 0, a_1 exp(-alpha_1 z/2), 0,
    a_2 exp(-alpha_2 z/2), 0]``. Because ``tanh(0) = 0``, the output equals the
    baseline (the initial amplitudes) at ``z = 0``. ``main`` reads the six outputs as
    (real, imag) pairs for pump, signal and idler.
    """

    def __init__(self, in_features: int = 9, hidden_dim: int = 128, out_features: int = 6):
        super().__init__()
        self.length_scale = 0.005  # same units as z (metres, as main() supplies it)
        self.output_scale = 0.1
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, out_features),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # Rescale each input group before the MLP; the clamps bound inputs that fall
        # outside the expected ranges.
        scaled = state.clone()
        scaled[:, 0:1] = torch.clamp(state[:, 0:1], 0.0, 5.0)
        scaled[:, 1:2] = torch.clamp(state[:, 1:2] * 1e-6, -5.0, 5.0)
        scaled[:, 2:5] = torch.clamp(state[:, 2:5] / 100.0, 0.0, 5.0)
        scaled[:, 5:8] = torch.clamp(state[:, 5:8], -2.0, 2.0)
        z = torch.clamp(state[:, 8:9], min=0.0)  # negative distances are treated as 0
        scaled[:, 8:9] = torch.clamp(z / self.length_scale, 0.0, 4.0)
        raw = self.mlp(scaled)

        # Analytic baseline: real parts decay as exp(-alpha * z / 2); imaginary parts are 0.
        alpha = torch.clamp(state[:, 2:5], min=0.0, max=500.0)
        ic = state[:, 5:8]
        decay = torch.exp(-0.5 * alpha * z)
        zeros = torch.zeros_like(z)
        base = torch.cat(
            [
                ic[:, 0:1] * decay[:, 0:1],
                zeros,
                ic[:, 1:2] * decay[:, 1:2],
                zeros,
                ic[:, 2:3] * decay[:, 2:3],
                zeros,
            ],
            dim=-1,
        )
        # The gate vanishes at z = 0, so the learned correction cannot alter the
        # initial condition.
        gate = torch.tanh(z / self.length_scale)
        return base + gate * self.output_scale * raw


def compute_overlap_integral(e_pump: torch.Tensor, e_signal: torch.Tensor) -> torch.Tensor:
    """Return the normalized overlap ``|<e_pump, e_signal>| / (||e_pump|| ||e_signal||)``.

    Sums run over every dimension except the first (batch), so the result has shape
    ``(batch,)``. The squared ratio is clamped to ``[0, 1]`` before the square root.
    The power product is floored at ``1e-24``, so all-zero fields give 0 instead of NaN.
    """
    dim_sum = tuple(range(1, e_pump.dim()))
    numerator = torch.sum(e_pump * torch.conj(e_signal), dim=dim_sum)
    power_pump = torch.sum(torch.abs(e_pump) ** 2, dim=dim_sum)
    power_signal = torch.sum(torch.abs(e_signal) ** 2, dim=dim_sum)
    denom = torch.clamp(power_pump * power_signal, min=1e-24)
    overlap_factor = torch.abs(numerator) ** 2 / denom
    return torch.sqrt(torch.clamp(overlap_factor, min=0.0, max=1.0))


def physics_converter(
    fields_pump: torch.Tensor,
    fields_signal: torch.Tensor,
    n_eff_pump: torch.Tensor,
    n_eff_signal: torch.Tensor,
    *,
    pump_wavelength: float = 7.75e-7,
    signal_wavelength: float = 1.55e-6,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert mode fields and effective indices into ``(coupling, delta_beta)``.

    Args:
        fields_pump: pump field profiles, batch first.
        fields_signal: signal field profiles, same shape as ``fields_pump``.
        n_eff_pump: pump effective index.
        n_eff_signal: signal effective index.
        pump_wavelength: pump wavelength in metres (default 775 nm).
        signal_wavelength: signal wavelength in metres (default 1550 nm).

    Returns:
        ``coupling``: the normalized overlap from ``compute_overlap_integral``,
        reshaped to ``(batch, 1)`` when it is 1-D.
        ``delta_beta``: ``beta_p - 2 * beta_s`` with ``beta = 2*pi*n_eff / wavelength``
        (rad/m when wavelengths are in metres).
    """
    two_pi = 2.0 * math.pi
    coupling = compute_overlap_integral(fields_pump, fields_signal)
    beta_p = two_pi / pump_wavelength * n_eff_pump
    beta_s = two_pi / signal_wavelength * n_eff_signal
    delta_beta = beta_p - 2.0 * beta_s
    if coupling.dim() == 1:
        coupling = coupling.unsqueeze(-1)
    return coupling, delta_beta


def alpha_db_cm_to_np_per_m(alpha_db_cm: torch.Tensor) -> torch.Tensor:
    """Convert losses from dB/cm to 1/m: multiply by ``100 * ln(10) / 10``.

    The factor is a Python float, so normal type promotion applies. Floating inputs
    keep their dtype at full precision, and integer inputs are promoted to the default
    float dtype instead of having the factor truncated to 23.
    """
    factor = 100.0 * math.log(10.0) / 10.0
    return alpha_db_cm * factor


def _load_state_dict(path: Path, device: torch.device) -> dict[str, torch.Tensor]:
    """Read a safetensors checkpoint onto ``device``.

    ``safetensors`` is imported lazily, so the model definitions in this module can be
    imported without that package installed.
    """
    from safetensors.torch import load_file

    return load_file(str(path), device=str(device))


def load_transverse(path: Path, device: torch.device) -> EigenmodeDeepONet:
    """Build an ``EigenmodeDeepONet`` on ``device``, load ``path`` strictly, set eval mode."""
    model = EigenmodeDeepONet().to(device)
    model.load_state_dict(_load_state_dict(path, device), strict=True)
    model.eval()
    return model


def load_propagator(path: Path, device: torch.device) -> CWEPINN:
    """Build a ``CWEPINN`` on ``device``, load ``path`` strictly, set eval mode."""
    model = CWEPINN().to(device)
    model.load_state_dict(_load_state_dict(path, device), strict=True)
    model.eval()
    return model


def main() -> None:
    """Run the two-stage inference from the command line and print the results."""
    parser = argparse.ArgumentParser(
        description="Run NanoPhotoNet-MPM inference from safetensors checkpoints."
    )
    parser.add_argument("--width", type=float, default=693.49, help="Waveguide width in nm")
    parser.add_argument("--height", type=float, default=200.0, help="Waveguide height in nm")
    parser.add_argument("--length", type=float, default=5.0, help="Interaction length in mm")
    parser.add_argument(
        "--transverse-checkpoint",
        type=Path,
        default=Path("checkpoints/transverse_solver.safetensors"),
        help="Path to the transverse solver safetensors file",
    )
    parser.add_argument(
        "--propagator-checkpoint",
        type=Path,
        default=Path("checkpoints/propagator.safetensors"),
        help="Path to the longitudinal propagator safetensors file",
    )
    args = parser.parse_args()

    # Report bad inputs as usage errors, rather than a traceback (missing files) or a
    # silent clamp (CWEPINN treats negative distances as 0).
    if args.width <= 0.0 or args.height <= 0.0:
        parser.error("--width and --height must be positive")
    if args.length < 0.0:
        parser.error("--length must be non-negative")
    for checkpoint in (args.transverse_checkpoint, args.propagator_checkpoint):
        if not checkpoint.is_file():
            parser.error(f"checkpoint not found: {checkpoint}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transverse_solver = load_transverse(args.transverse_checkpoint, device)
    propagator = load_propagator(args.propagator_checkpoint, device)

    # 128 x 128 cross-section grid spanning +/-30 in x and +/-20 in y. TrunkNet divides
    # by these same extents.
    x_norm = np.linspace(-30.0, 30.0, 128)
    y_norm = np.linspace(-20.0, 20.0, 128)
    xx, yy = np.meshgrid(x_norm, y_norm)
    coords_np = np.stack([xx.reshape(-1), yy.reshape(-1)], axis=-1)
    coords = torch.tensor(coords_np, dtype=torch.float32, device=device).unsqueeze(0)

    # Geometry vectors: width and height converted from the CLI's nm by /1000,
    # followed by three extra features.
    # NOTE(review): the meaning of the last three entries (pump 1,1,1 vs signal 2,0,1)
    # is not visible in this file; confirm against the training data pipeline.
    geom_p = torch.tensor(
        [[args.width / 1000.0, args.height / 1000.0, 1.0, 1.0, 1.0]],
        dtype=torch.float32,
        device=device,
    )
    geom_s = torch.tensor(
        [[args.width / 1000.0, args.height / 1000.0, 2.0, 0.0, 1.0]],
        dtype=torch.float32,
        device=device,
    )

    # Pure inference. no_grad also makes the outputs safe to convert with .numpy().
    with torch.no_grad():
        fields_p, n_eff_p = transverse_solver(geom_p, coords)
        fields_s, n_eff_s = transverse_solver(geom_s, coords)
        coupling, delta_beta = physics_converter(fields_p, fields_s, n_eff_p, n_eff_s)

        z_val = torch.tensor([[args.length / 1000.0]], dtype=torch.float32, device=device)  # mm -> m
        # Losses in dB/cm for pump, signal, idler (same order as the initial amplitudes).
        alpha_db = torch.tensor([[5.0, 2.0, 2.0]], dtype=torch.float32, device=device)
        alpha_vals = alpha_db_cm_to_np_per_m(alpha_db)
        ic_vals = torch.tensor([[1.0, 0.001, 0.001]], dtype=torch.float32, device=device)
        # Column order must match CWEPINN's expected state layout.
        state_prop = torch.cat([coupling, delta_beta, alpha_vals, ic_vals, z_val], dim=-1)
        pred_envelopes = propagator(state_prop)

    env = pred_envelopes[0].cpu().numpy()
    pump = complex(env[0], env[1])
    signal = complex(env[2], env[3])
    idler = complex(env[4], env[5])

    print("NanoPhotoNet-MPM inference")
    print(f"Device: {device}")
    print(f"Geometry: width={args.width:.2f} nm, height={args.height:.2f} nm")
    print(f"Interaction length: {args.length:.2f} mm")
    print(f"Pump n_eff: {n_eff_p.item():.6f}")
    print(f"Signal n_eff: {n_eff_s.item():.6f}")
    print(f"Overlap proxy K: {coupling.item():.6f}")
    print(f"Phase mismatch Delta beta: {delta_beta.item():.6e} rad/m")
    print(f"Pump envelope A_p: {pump.real:.6f} + {pump.imag:.6f}j")
    print(f"Signal envelope A_s: {signal.real:.6f} + {signal.imag:.6f}j")
    print(f"Idler envelope A_i: {idler.real:.6f} + {idler.imag:.6f}j")


if __name__ == "__main__":
    main()