#!/usr/bin/env python3 # -*- coding: utf-8 -*- # ============================================================================= # Colab setup (uncomment if needed) # ============================================================================= # !pip install -q e3nn torch-scatter torch-cluster matplotlib numpy # !pip install -q pymatgen # optional, for real MP data import os import random import math import copy import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt try: import e3nn from e3nn import o3 from e3nn.nn import Gate from e3nn.o3 import FullyConnectedTensorProduct HAS_E3NN = True except ImportError: raise RuntimeError("pip install e3nn") try: from torch_cluster import radius_graph from torch_scatter import scatter HAS_PYG = True except ImportError: HAS_PYG = False try: from mp_api.client import MPRester HAS_PMG = True except Exception: HAS_PMG = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") USE_COMPILE = False # Disabled: buffer-mutation in neuro-morphic layers breaks torch.compile USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() USE_FP16 = torch.cuda.is_available() and not USE_BF16 phi = (1 + np.sqrt(5)) / 2 SAVE_DIR = "./checkpoints_v41_q" PLOT_DIR = "./plots_v41_q" DFT_DIR = "./dft_data_v41_q" AL_DIR = "./active_learning_v41_q" os.makedirs(SAVE_DIR, exist_ok=True) os.makedirs(PLOT_DIR, exist_ok=True) os.makedirs(DFT_DIR, exist_ok=True) os.makedirs(AL_DIR, exist_ok=True) DEBUG_MODE = False if DEBUG_MODE: NUM_MODELS, NUM_LAYERS, NUM_CONFIGS = 1, 8, 200 else: NUM_MODELS, NUM_LAYERS, NUM_CONFIGS = 1, 16, 2000 FORCE_WEIGHT = 1.0 STRESS_WEIGHT = 0.05 CONSIST_WEIGHT = 0.02 AUX_WEIGHT = 0.03 ELECTRO_WEIGHT = 0.03 MULTIFID_WEIGHT = 0.03 SSL_WEIGHT = 0.02 BIO_WEIGHT = 0.03 GRAD_CLIP = 1.0 CUTOFF = 9.0 MAX_NEIGHB = 96 kB_eV = 8.617333262e-5 CONSIST_EVERY = 5 IRREPS_NODE = o3.Irreps("35x0e + 14x1o + 7x2e") IRREPS_FORCE = o3.Irreps("1x1o") IRREPS_EDGE = o3.Irreps("1x0e + 1x1o + 1x2e") # ============================================================================= # 0. EXACT SRI CHAKRA GEOMETRY (User-supplied intersection arithmetic) # ============================================================================= SRI_TRIANGLES_MM = np.array([ [[-169.71, 240.0], [169.71, 240.0], [0.0, 0.0]], [[-78.10499177949904, 197.92315674002487], [78.10499177949904, 197.92315674002487], [0.0, 27.32284325642113]], [[-98.71823312733801, 220.03828947357886], [98.71823312733801, 220.03828947357886], [0.0, 55.02257894715772]], [[-114.9488500600036, 160.66014976539617], [114.9488500600036, 160.66014976539617], [0.0, 105.32229953019834]], [[-160.66014976539617, 0.0], [160.66014976539617, 0.0], [0.0, 250.0]], [[-103.12199145016105, 0.0], [103.12199145016105, 0.0], [0.0, 220.03828947357886]], [[-160.66014976539617, 0.0], [160.66014976539617, 0.0], [0.0, 90.4856922951427]], [[-197.92315674002487, 0.0], [197.92315674002487, 0.0], [0.0, 116.35142605010424]], [[-174.24660560764943, 0.0], [174.24660560764943, 0.0], [0.0, 124.61190803072795]] ], dtype=np.float64) SRI_TRIANGLES = SRI_TRIANGLES_MM / 180.0 def line_intersection(p1, p2, p3, p4): x1, y1 = p1; x2, y2 = p2 x3, y3 = p3; x4, y4 = p4 denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) if abs(denom) < 1e-12: return None px = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denom py = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denom return np.array([px, py], dtype=np.float64) def build_sri_chakra_43(): points = [] pt_key_to_idx = {} def add_point(x, y): key = (round(float(x), 10), round(float(y), 10)) if key not in pt_key_to_idx: pt_key_to_idx[key] = len(points) points.append(np.array([float(x), float(y)], dtype=np.float64)) return pt_key_to_idx[key] tri_vert_ids = [] for tri in SRI_TRIANGLES: ids = [add_point(p[0], p[1]) for p in tri] tri_vert_ids.append(ids) raw_edges = [] for ti, verts in enumerate(tri_vert_ids): for a in range(3): b = (a + 1) % 3 v1, v2 = verts[a], verts[b] if v1 > v2: v1, v2 = v2, v1 raw_edges.append((ti, v1, v2)) # Pairwise intersections (different triangles only) for i in range(len(raw_edges)): ti, a1, a2 = raw_edges[i] p1 = points[a1]; p2 = points[a2] for j in range(i + 1, len(raw_edges)): tj, b1, b2 = raw_edges[j] if ti == tj: continue p3 = points[b1]; p4 = points[b2] ipt = line_intersection(p1, p2, p3, p4) if ipt is None: continue def on_seg(pt, a, b, tol=1e-8): return (min(a[0], b[0]) - tol <= pt[0] <= max(a[0], b[0]) + tol and min(a[1], b[1]) - tol <= pt[1] <= max(a[1], b[1]) + tol) if on_seg(ipt, p1, p2) and on_seg(ipt, p3, p4): add_point(ipt[0], ipt[1]) n_pts = len(points) def pts_on_edge(v1, v2): a = points[v1]; b = points[v2] vec = b - a len2 = np.dot(vec, vec) on_seg = [v1, v2] for idx in range(n_pts): if idx == v1 or idx == v2: continue p = points[idx] cross = (p[0] - a[0]) * (b[1] - a[1]) - (p[1] - a[1]) * (b[0] - a[0]) if abs(cross) > 1e-7: continue if (min(a[0], b[0]) - 1e-7 <= p[0] <= max(a[0], b[0]) + 1e-7 and min(a[1], b[1]) - 1e-7 <= p[1] <= max(a[1], b[1]) + 1e-7): on_seg.append(idx) def proj(idx): p = points[idx] - a return np.dot(p, vec) / (len2 + 1e-12) on_seg.sort(key=proj) return on_seg adj = np.zeros((n_pts, n_pts), dtype=np.float32) for (ti, v1, v2) in raw_edges: seg = pts_on_edge(v1, v2) for k in range(len(seg) - 1): u, v = seg[k], seg[k + 1] adj[u, v] = adj[v, u] = 1.0 faces = [] for i in range(n_pts): nbr_i = np.where(adj[i])[0] for j in nbr_i: if j <= i: continue nbr_j = np.where(adj[j])[0] common = np.intersect1d(nbr_i, nbr_j) for k in common: if k > j: xi, yi = points[i] xj, yj = points[j] xk, yk = points[k] area = (xj - xi) * (yk - yi) - (xk - xi) * (yj - yi) if area > 1e-8: faces.append((int(i), int(j), int(k))) if len(faces) != 43: # Guaranteed fallback n_faces = 43 adj_f = torch.zeros(n_faces, n_faces) offsets = [0, 1, 9, 19, 29, 43] for ring in range(5): start, end = offsets[ring], offsets[ring + 1] n = end - start for i in range(n): node_i = start + i for offset in [1, 2]: j = (i + offset) % n node_j = start + j adj_f[node_i, node_j] = 1 adj_f[node_j, node_i] = 1 for ring in range(4): r_start, r_end = offsets[ring], offsets[ring + 1] next_start, next_end = offsets[ring + 1], offsets[ring + 2] r_n = r_end - r_start next_n = next_end - next_start for i in range(next_n): outer_node = next_start + i inner_idx_1 = int((i / next_n) * r_n) % r_n inner_idx_2 = (inner_idx_1 + 1) % r_n adj_f[outer_node, r_start + inner_idx_1] = 1 adj_f[r_start + inner_idx_1, outer_node] = 1 adj_f[outer_node, r_start + inner_idx_2] = 1 adj_f[r_start + inner_idx_2, outer_node] = 1 for i in range(1, n_faces): adj_f[0, i] = 1 adj_f[i, 0] = 1 adj_f = ((adj_f + adj_f.t()) > 0).float() avarana = torch.zeros(n_faces, dtype=torch.long) for i in range(43): if i < 1: avarana[i] = 6 elif i < 9: avarana[i] = 5 elif i < 19: avarana[i] = 4 elif i < 29: avarana[i] = 3 else: avarana[i] = 2 return adj_f, avarana # fallback: no points/faces n_faces = len(faces) SRI_ADJ = np.zeros((n_faces, n_faces), dtype=np.float32) for i in range(n_faces): set_i = set(faces[i]) for j in range(i + 1, n_faces): if len(set_i & set(faces[j])) >= 2: SRI_ADJ[i, j] = SRI_ADJ[j, i] = 1.0 centroids = [] for (i, j, k) in faces: cx = (points[i][0] + points[j][0] + points[k][0]) / 3.0 cy = (points[i][1] + points[j][1] + points[k][1]) / 3.0 centroids.append(np.sqrt(cx ** 2 + cy ** 2)) order = np.argsort(centroids)[::-1] avarana = torch.zeros(n_faces, dtype=torch.long) ring_counts = [(14, 4), (10, 5), (10, 6), (8, 7), (1, 8)] idx = 0 for cnt, av in ring_counts: for _ in range(cnt): if idx < n_faces: avarana[order[idx]] = av idx += 1 return torch.from_numpy(SRI_ADJ), avarana, points, faces def build_srichakra_cortical_graph(device=None, dtype=torch.float32): result = build_sri_chakra_43() if len(result) == 4: adj, avarana, points, faces = result adj = adj.to(device=device, dtype=dtype) avarana = avarana.to(device=device) # Build 3D coordinates from exact face centroids coords_list = [] for (i, j, k) in faces: cx = (points[i][0] + points[j][0] + points[k][0]) / 3.0 cy = (points[i][1] + points[j][1] + points[k][1]) / 3.0 theta = math.atan2(cy, cx) z = 0.08 * math.sin(3.0 * theta) coords_list.append([float(cx), float(cy), float(z)]) coords = torch.tensor(coords_list, dtype=dtype, device=device) else: # Fallback: use idealized ring coordinates adj, avarana = result[:2] adj = adj.to(device=device, dtype=dtype) avarana = avarana.to(device=device) radii = [0.0, 0.18, 0.36, 0.58, 0.84] counts = [1, 8, 10, 10, 14] coords_list = [] for ring_idx, (r, n) in enumerate(zip(radii, counts)): if n == 1: coords_list.append([0.0, 0.0, 0.0]) continue phase = 0.0 if ring_idx == 1: phase = np.pi / 8 elif ring_idx == 2: phase = np.pi / 10 elif ring_idx == 3: phase = np.pi / 20 elif ring_idx == 4: phase = np.pi / 14 for i in range(n): theta = 2 * np.pi * i / n + phase x = r * np.cos(theta) y = r * np.sin(theta) z = 0.08 * np.sin(3.0 * theta) if ring_idx > 0 else 0.0 coords_list.append([x, y, z]) coords = torch.tensor(coords_list, dtype=dtype, device=device) # Map avarana (4..8) to region (2..6) for laminar profiles region = torch.zeros_like(avarana) region[avarana == 8] = 6 # bindu / center region[avarana == 7] = 5 region[avarana == 6] = 4 region[avarana == 5] = 3 region[avarana == 4] = 2 # outermost n_nodes = len(region) laminar_profile = torch.zeros(n_nodes, 6, dtype=dtype, device=device) for i in range(n_nodes): r = region[i].item() if r == 6: laminar_profile[i] = torch.tensor([0.10, 0.10, 0.18, 0.22, 0.20, 0.20], dtype=dtype, device=device) elif r >= 5: laminar_profile[i] = torch.tensor([0.18, 0.18, 0.22, 0.16, 0.14, 0.12], dtype=dtype, device=device) elif r == 4: laminar_profile[i] = torch.tensor([0.14, 0.16, 0.24, 0.18, 0.16, 0.12], dtype=dtype, device=device) elif r == 3: laminar_profile[i] = torch.tensor([0.12, 0.14, 0.20, 0.20, 0.18, 0.16], dtype=dtype, device=device) else: laminar_profile[i] = torch.tensor([0.10, 0.12, 0.18, 0.20, 0.20, 0.20], dtype=dtype, device=device) functional_module = torch.zeros_like(region) functional_module[region == 6] = 0 functional_module[region == 5] = 1 functional_module[(region == 4) | (region == 3)] = 2 functional_module[region == 2] = 3 return adj, region, coords, laminar_profile, functional_module def generate_105node_coordinates(device=None, dtype=torch.float32): counts = [1, 8, 24, 36, 36] radii = [0.0, 0.16, 0.33, 0.59, 0.92] coords = [] for ring_idx, (n, r) in enumerate(zip(counts, radii)): if n == 1: coords.append([0.0, 0.0, 0.0]) continue phase = np.pi / max(2, n) for i in range(n): theta = 2 * np.pi * i / n + phase coords.append([r * np.cos(theta), r * np.sin(theta), 0.04 * np.sin(theta * (ring_idx + 1) * phi)]) coords = torch.tensor(coords, dtype=dtype, device=device) assert coords.shape == (105, 3) return coords def build_105node_hierarchical_graph(device=None, dtype=torch.float32): counts = [1, 8, 24, 36, 36] offsets = np.cumsum([0] + counts).tolist() n_nodes = 105 coords = generate_105node_coordinates(device=device, dtype=dtype) adj = torch.zeros(n_nodes, n_nodes, dtype=dtype, device=device) for ring in range(1, len(counts)): s, e = offsets[ring], offsets[ring + 1] n = e - s for i in range(n): for hop in [1, 2]: j = (i + hop) % n adj[s + i, s + j] = 1 adj[s + j, s + i] = 1 for ring in range(4): a0, a1 = offsets[ring], offsets[ring + 1] b0, b1 = offsets[ring + 1], offsets[ring + 2] na = a1 - a0 nb = b1 - b0 for i in range(nb): j0 = int((i / nb) * na) % na j1 = (j0 + 1) % na adj[b0 + i, a0 + j0] = 1 adj[b0 + i, a0 + j1] = 1 adj[a0 + j0, b0 + i] = 1 adj[a0 + j1, b0 + i] = 1 for i in range(1, n_nodes): adj[0, i] = 1 adj[i, 0] = 1 region = torch.zeros(n_nodes, dtype=torch.long, device=device) region[0:1] = 0 region[1:9] = 1 region[9:33] = 2 region[33:69] = 3 region[69:105] = 4 return adj, region, coords # ============================================================================= # 1. GEOMETRY # ============================================================================= class BSplineBasis(nn.Module): def __init__(self, num_basis=32, cutoff=9.0, degree=3): super().__init__() self.num_basis = num_basis self.cutoff = cutoff self.degree = degree self.spacing = cutoff / (num_basis - degree) knots = torch.arange(-degree, num_basis + 1) * self.spacing self.register_buffer("knots", knots) def _cox_de_boor(self, t, i, k): if k == 0: return ((self.knots[i] <= t) & (t < self.knots[i + 1])).float() left_num = t - self.knots[i] left_den = self.knots[i + k] - self.knots[i] left = torch.where(left_den > 1e-8, left_num / left_den, torch.zeros_like(t)) right_num = self.knots[i + k + 1] - t right_den = self.knots[i + k + 1] - self.knots[i + 1] right = torch.where(right_den > 1e-8, right_num / right_den, torch.zeros_like(t)) return left * self._cox_de_boor(t, i, k - 1) + right * self._cox_de_boor(t, i + 1, k - 1) def forward(self, dist): t = dist.clamp(min=0, max=self.cutoff - 1e-6) basis = [self._cox_de_boor(t, i, self.degree) for i in range(self.num_basis)] return torch.stack(basis, dim=-1) def build_neighbor_list(pos, cutoff=9.0, max_n=96): B, N, _ = pos.shape pos_flat = pos.view(B * N, 3) batch = torch.arange(B, device=pos.device).unsqueeze(1).expand(B, N).reshape(-1) if HAS_PYG: edge_index = radius_graph(pos_flat, r=cutoff, batch=batch, max_num_neighbors=max_n) src, dst = edge_index edge_vec = pos_flat[dst] - pos_flat[src] edge_dist = torch.norm(edge_vec, dim=-1) else: edge_list = [] for b in range(B): pb = pos[b] dist = torch.cdist(pb, pb) dist = dist + torch.eye(N, device=dist.device) * 1e9 sorted_dist, sorted_idx = torch.sort(dist, dim=-1) valid = sorted_dist < cutoff for i in range(N): valid_nbrs = sorted_idx[i][valid[i]] n_valid = min(len(valid_nbrs), max_n) if n_valid == 0: continue nbrs = valid_nbrs[:n_valid] src = torch.full((n_valid,), i + b * N, dtype=torch.long, device=pos.device) dst = nbrs + b * N edge_list.append(torch.stack([src, dst], dim=0)) if not edge_list: edge_index = torch.stack([torch.arange(B * N, device=pos.device), torch.arange(B * N, device=pos.device)], dim=0) else: edge_index = torch.cat(edge_list, dim=1) src, dst = edge_index edge_vec = pos_flat[dst] - pos_flat[src] edge_dist = torch.norm(edge_vec, dim=-1) neigh = [[] for _ in range(B * N)] for s, d in zip(src.tolist(), dst.tolist()): neigh[s].append(d) neigh[d].append(s) triplet_src, triplet_mid, triplet_dst = [], [], [] for j in range(B * N): nb = neigh[j] if len(nb) < 2: continue for ii in range(len(nb)): i = nb[ii] for kk in range(ii + 1, len(nb)): k = nb[kk] triplet_src.append(i) triplet_mid.append(j) triplet_dst.append(k) if not triplet_src: triplet_idx = torch.zeros((3, 0), dtype=torch.long, device=pos.device) else: triplet_idx = torch.stack([ torch.tensor(triplet_src, device=pos.device), torch.tensor(triplet_mid, device=pos.device), torch.tensor(triplet_dst, device=pos.device) ], dim=0) return edge_index, edge_vec, edge_dist, pos_flat, batch, triplet_idx def build_quadruplets(pos, cutoff=6.0, max_quads=4096): B, N, _ = pos.shape quads = [] for b in range(B): d = torch.cdist(pos[b], pos[b]) for i in range(N): nbrs = torch.where((d[i] < cutoff) & (d[i] > 1e-8))[0] if len(nbrs) >= 3: m = min(len(nbrs), 6) nbrs = nbrs[:m] for a in range(len(nbrs)): for c in range(a + 1, len(nbrs)): for e in range(c + 1, len(nbrs)): quads.append([b * N + i, b * N + nbrs[a], b * N + nbrs[c], b * N + nbrs[e]]) if len(quads) >= max_quads: break if len(quads) >= max_quads: break if len(quads) >= max_quads: break if len(quads) >= max_quads: break if len(quads) == 0: return torch.zeros((4, 0), dtype=torch.long, device=pos.device) return torch.tensor(quads, dtype=torch.long, device=pos.device).t() def build_latent_rewire_graph(x, pos, k=16, alpha=0.6): B, N, D = x.shape x_norm = F.normalize(x[..., :min(32, D)], dim=-1) sim = torch.einsum("bid,bjd->bij", x_norm, x_norm) geo = torch.cdist(pos, pos) score = alpha * sim - (1 - alpha) * (geo / (geo.mean(dim=(-1, -2), keepdim=True) + 1e-8)) eye = torch.eye(N, device=pos.device).unsqueeze(0) score = score - 1e6 * eye topk = torch.topk(score, k=min(k, N - 1), dim=-1).indices src_all, dst_all = [], [] for b in range(B): for i in range(N): nbrs = topk[b, i] src = torch.full((len(nbrs),), b * N + i, dtype=torch.long, device=pos.device) dst = nbrs + b * N src_all.append(src) dst_all.append(dst) src_all = torch.cat(src_all) dst_all = torch.cat(dst_all) edge_index = torch.stack([src_all, dst_all], dim=0) pos_flat = pos.reshape(B * N, 3) edge_vec = pos_flat[dst_all] - pos_flat[src_all] edge_dist = torch.norm(edge_vec, dim=-1) return edge_index, edge_vec, edge_dist # ============================================================================= # 2. CHANNELS # ============================================================================= def compute_defect_channels(pos, z): if pos.dim() == 2: pos = pos.unsqueeze(0) z = z.unsqueeze(0) B, N, _ = pos.shape dist = torch.cdist(pos, pos) + torch.eye(N, device=pos.device).unsqueeze(0) * 1e6 nn_dist = dist.min(dim=-1).values local_mean = dist.masked_fill(dist > 20, 0.0).sum(dim=-1) / (dist < 20).float().sum(dim=-1).clamp(min=1.0) vacancy_like = torch.relu(nn_dist - 2.4) interstitial_like = torch.relu(2.0 - nn_dist) substitution_like = torch.abs(z.float() - z.float().mean(dim=1, keepdim=True)) / 20.0 return torch.stack([ vacancy_like, interstitial_like, substitution_like, local_mean, nn_dist, torch.exp(-nn_dist), torch.relu(local_mean - 3.0), torch.relu(2.5 - local_mean), ], dim=-1) def compute_diffusion_channels(pos, z): if pos.dim() == 2: pos = pos.unsqueeze(0) z = z.unsqueeze(0) B, N, _ = pos.shape dist = torch.cdist(pos, pos) + torch.eye(N, device=pos.device).unsqueeze(0) * 1e6 h_mask = (z == 1).float() weighted_h = dist * h_mask.unsqueeze(1) + (1.0 - h_mask.unsqueeze(1)) * 1e6 nearest_h = weighted_h.min(dim=-1).values.clamp(max=20.0) nearest_any = dist.min(dim=-1).values return torch.stack([ nearest_any, nearest_h, torch.exp(-nearest_any / 2.0), torch.relu(nearest_any - 1.5), torch.exp(-nearest_h / 1.0), torch.sin(nearest_any * phi), torch.cos(nearest_any * np.pi), torch.relu(3.5 - nearest_any), ], dim=-1) def compute_phonon_channels(pos, z): if pos.dim() == 2: pos = pos.unsqueeze(0) z = z.unsqueeze(0) centered = pos - pos.mean(dim=1, keepdim=True) amp = centered.norm(dim=-1) return torch.stack([ amp, amp ** 2, torch.sin(amp), torch.cos(amp * phi), z.float().view(z.shape[0], -1) / 30.0, torch.exp(-amp), torch.relu(amp - amp.mean(dim=1, keepdim=True)), torch.relu(amp.mean(dim=1, keepdim=True) - amp), ], dim=-1) def build_multi_physics_channels(pos, z): if pos.dim() == 2: pos = pos.unsqueeze(0) z = z.unsqueeze(0) B, N, _ = pos.shape diff = pos.unsqueeze(2) - pos.unsqueeze(1) z = z.reshape(z.shape[0] if pos.dim()==3 else 1, -1); dist = torch.norm(diff, dim=-1) + 1e-8 local_density = torch.stack([ dist.mean(dim=-1), dist.std(dim=-1), torch.exp(-dist).mean(dim=-1), torch.sin(dist).mean(dim=-1), dist.max(dim=-1).values, dist.min(dim=-1).values, z.float().view(z.shape[0], -1) / 30.0, dist.var(dim=-1), ], dim=-1) gradient_flow = torch.stack([ dist.mean(dim=-1), dist.std(dim=-1), pos.norm(dim=-1), z.float().view(z.shape[0], -1).view(z.shape[0], -1) / 20.0, torch.exp(-dist).mean(dim=-1), torch.sin(dist * phi).mean(dim=-1), torch.cos(dist).mean(dim=-1), (pos - pos.mean(dim=1, keepdim=True)).norm(dim=-1), ], dim=-1) interaction_energy = torch.stack([ dist.var(dim=-1), torch.relu(dist - 2.0).mean(dim=-1), torch.exp(-dist / 2.0).mean(dim=-1), torch.sin(dist * phi).mean(dim=-1), (dist ** 2).mean(dim=-1), z.float().view(z.shape[0], -1).float().view(z.shape[0], -1).float().view(z.shape[0], -1).float() / 10.0, torch.cos(dist).mean(dim=-1), torch.relu(3.0 - dist).mean(dim=-1), ], dim=-1) global_context = torch.stack([ dist.mean(dim=-1), dist.max(dim=-1).values, dist.min(dim=-1).values, torch.log(dist.mean(dim=-1) + 1e-6), (pos - pos.mean(dim=1, keepdim=True)).norm(dim=-1), pos.norm(dim=-1), torch.exp(-dist).mean(dim=-1), z.float().view(z.shape[0], -1) / 50.0, ], dim=-1) structural_stability = torch.stack([ torch.abs(pos[..., 0]), torch.abs(pos[..., 1]), pos.norm(dim=-1), dist.mean(dim=-1), dist.var(dim=-1), z.float(), z.float() ** 0.5, torch.relu(dist - 1.5).mean(dim=-1), ], dim=-1) phase_aware = torch.stack([ torch.sin(pos[..., 0] * phi), torch.cos(pos[..., 1] * phi), torch.sin(pos[..., 2] * np.pi), torch.exp(-pos.norm(dim=-1) / 2.0), torch.sin(dist.mean(dim=-1) * np.pi), torch.cos(dist.mean(dim=-1) * np.pi), torch.exp(-dist).std(dim=-1), torch.sin(dist.std(dim=-1) * phi), ], dim=-1) defect = compute_defect_channels(pos, z) diffusion = compute_diffusion_channels(pos, z) phonon = compute_phonon_channels(pos, z) return local_density, gradient_flow, interaction_energy, global_context, structural_stability, phase_aware, defect, diffusion, phonon # ============================================================================= # 3. COMPLEX HILBERT / ENCODER (Quantum -> Neuron bridge) # ============================================================================= class ComplexHilbertModule(nn.Module): def __init__(self, in_dim=8, hidden=16): super().__init__() self.local_real = nn.Sequential(nn.Linear(in_dim, hidden), nn.SiLU(), nn.Linear(hidden, hidden)) self.local_imag = nn.Sequential(nn.Linear(in_dim, hidden), nn.SiLU(), nn.Linear(hidden, hidden)) self.global_real = nn.Sequential(nn.Linear(in_dim, hidden), nn.SiLU(), nn.Linear(hidden, hidden)) self.global_imag = nn.Sequential(nn.Linear(in_dim, hidden), nn.SiLU(), nn.Linear(hidden, hidden)) self.coherence_head = nn.Sequential(nn.Linear(hidden * 4, hidden), nn.SiLU(), nn.Linear(hidden, 1)) self.phase_head = nn.Sequential(nn.Linear(hidden * 4, hidden), nn.SiLU(), nn.Linear(hidden, 1)) self.out_proj = nn.Sequential(nn.Linear(hidden * 8, 16), nn.SiLU()) # v4.1: Quantum fields -> neural oscillator parameters self.freq_head = nn.Sequential(nn.Linear(hidden * 4, hidden), nn.SiLU(), nn.Linear(hidden, 16)) self.sync_head = nn.Sequential(nn.Linear(hidden * 4, hidden), nn.SiLU(), nn.Linear(hidden, 16)) self.thresh_head = nn.Sequential(nn.Linear(hidden * 4, hidden), nn.SiLU(), nn.Linear(hidden, 1), nn.Sigmoid()) def forward(self, local_x, global_x): rl = self.local_real(local_x) il = self.local_imag(local_x) rg = self.global_real(global_x) ig = self.global_imag(global_x) psi_local = torch.complex(rl.float(), il.float()) psi_global = torch.complex(rg.float(), ig.float()) psi_total = psi_local + (1.0 / phi) * psi_global interference = psi_local * torch.conj(psi_global) feat = torch.cat([ psi_total.real, psi_total.imag, torch.abs(psi_total), torch.angle(psi_total), interference.real, interference.imag, torch.abs(interference), torch.angle(interference) ], dim=-1) aux = torch.cat([rl, il, rg, ig], dim=-1) coherence = torch.sigmoid(self.coherence_head(aux)) phase_order = self.phase_head(aux) # Neural modulation from quantum Hilbert space neural_freq = self.freq_head(aux) # drives oscillator frequency neural_sync = self.sync_head(aux) # drives phase coupling threshold_mod = self.thresh_head(aux) # modulates LIF threshold return self.out_proj(feat), coherence, phase_order, neural_freq, neural_sync, threshold_mod class MultiPhysicsEncoderV4(nn.Module): def __init__(self, irreps_out=IRREPS_NODE): super().__init__() def block(): return nn.Sequential(nn.Linear(8, 16), nn.SiLU()) self.pingala = nn.ModuleDict({ "jal": block(), "vayu": block(), "agni": block(), "akasha": block(), "prithvi": block(), "defect": block(), "diffusion": block(), "phonon": block() }) self.ida = nn.ModuleDict({ "jal": block(), "vayu": block(), "agni": block(), "akasha": block(), "prithvi": block(), "defect": block(), "diffusion": block(), "phonon": block() }) self.hilbert = ComplexHilbertModule(8, 16) self.scalar_mlp = nn.Sequential(nn.Linear(16 * 17, 128), nn.SiLU(), nn.Linear(128, 56)) self.to_irreps = o3.Linear(o3.Irreps("56x0e"), irreps_out) self.mask_token = nn.Parameter(torch.zeros(1, 1, 8)) def forward(self, ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, mask_ratio=0.0): if self.training and mask_ratio > 0: B, N, _ = ld.shape mask = (torch.rand(B, N, 1, device=ld.device) < mask_ratio).float() ld = ld * (1 - mask) + self.mask_token.expand(B, N, -1) * mask jal_p = self.pingala["jal"](ld); jal_i = -self.ida["jal"](ld) vayu_p = self.pingala["vayu"](gf); vayu_i = -self.ida["vayu"](gf) agni_p = self.pingala["agni"](ie); agni_i = -self.ida["agni"](ie) akasha_p = self.pingala["akasha"](gc); akasha_i = -self.ida["akasha"](gc) prithvi_p = self.pingala["prithvi"](ss); prithvi_i = -self.ida["prithvi"](ss) defect_p = self.pingala["defect"](defect); defect_i = -self.ida["defect"](defect) diffusion_p = self.pingala["diffusion"](diffusion); diffusion_i = -self.ida["diffusion"](diffusion) phonon_p = self.pingala["phonon"](phonon); phonon_i = -self.ida["phonon"](phonon) hilbert_feat, coherence, phase_order, neural_freq, neural_sync, threshold_mod = self.hilbert(hs, gc) scalars = torch.cat([ jal_p, jal_i, vayu_p, vayu_i, agni_p, agni_i, akasha_p, akasha_i, prithvi_p, prithvi_i, defect_p, defect_i, diffusion_p, diffusion_i, phonon_p, phonon_i, hilbert_feat ], dim=-1) scalars_56 = self.scalar_mlp(scalars) return self.to_irreps(scalars_56), coherence, phase_order, neural_freq, neural_sync, threshold_mod # ============================================================================= # 4. MAYA IDA/PINGALA ROUTER (7-head E/I attention) # ============================================================================= class MayaIdaPingalaRouter(nn.Module): def __init__(self, dim, heads=7): super().__init__() self.heads = heads self.dim = dim self.head_dim = dim // heads assert self.head_dim * heads == dim # Pingala = Excitatory direct pathway self.q_p = nn.ModuleList([nn.Linear(dim, self.head_dim) for _ in range(heads)]) self.k_p = nn.ModuleList([nn.Linear(dim, self.head_dim) for _ in range(heads)]) self.v_p = nn.ModuleList([nn.Linear(dim, self.head_dim) for _ in range(heads)]) # Ida = Inhibitory suppressive pathway self.q_i = nn.ModuleList([nn.Linear(dim, self.head_dim) for _ in range(heads)]) self.k_i = nn.ModuleList([nn.Linear(dim, self.head_dim) for _ in range(heads)]) self.v_i = nn.ModuleList([nn.Linear(dim, self.head_dim) for _ in range(heads)]) # Maya: active attention that gates E/I ratio per head per sample self.maya = nn.ModuleList([ nn.Sequential(nn.Linear(dim, 2), nn.Softmax(dim=-1)) for _ in range(heads) ]) self.out = nn.Linear(self.head_dim * heads, dim) def forward(self, x, temperature=(1 / phi)): B, N, D = x.shape outs = [] for h in range(self.heads): # --- Pingala (Excitatory) --- q = self.q_p[h](x) k = self.k_p[h](x) v = self.v_p[h](x) scores = torch.einsum("bid,bjd->bij", q, k) / math.sqrt(self.head_dim) attn_p = F.softmax(scores / temperature, dim=-1) out_p = torch.einsum("bij,bjd->bid", attn_p, v) # --- Ida (Inhibitory) --- q = self.q_i[h](x) k = self.k_i[h](x) v = self.v_i[h](x) scores = torch.einsum("bid,bjd->bij", q, k) / math.sqrt(self.head_dim) attn_i = F.softmax(scores / temperature, dim=-1) out_i = torch.einsum("bij,bjd->bid", attn_i, v) # --- Maya attention: blend [excitation, inhibition] --- maya_w = self.maya[h](x.mean(dim=1)).unsqueeze(1) # (B, 1, 2) w_exc = maya_w[..., 0:1] w_inh = maya_w[..., 1:2] # Inhibition subtracts; net E/I output outs.append(w_exc * out_p - w_inh * out_i) return self.out(torch.cat(outs, dim=-1)) # ============================================================================= # 4.5 QUANTUM-NEURAL OSCILLATOR (Hilbert waves -> synchrony) # ============================================================================= class QuantumNeuralOscillator(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim self.register_buffer("time", torch.zeros(1)) # Learn to map quantum Hilbert outputs to continuous neural frequencies self.freq_scale = nn.Linear(16, dim) self.phase_scale = nn.Linear(16, dim) self.amp_gate = nn.Sequential(nn.Linear(dim + 16, dim), nn.Sigmoid()) def forward(self, x, neural_freq, neural_sync, threshold_mod): B, N, D = x.shape t = self.time.to(x.device) # Quantum-modulated frequencies freq = 0.06 + 0.05 * torch.tanh(self.freq_scale(neural_freq)) # (B, N, D) phase = 0.5 * torch.tanh(self.phase_scale(neural_sync)) # (B, N, D) # Amplitude gated by synchrony field and state energy gate_input = torch.cat([x, neural_sync], dim=-1) amp = self.amp_gate(gate_input) * torch.exp(-torch.norm(x, dim=-1, keepdim=True) * 0.1) # Triple-band oscillatory waves (golden ratio harmonics) o1 = torch.sin(t * freq + phase) o2 = torch.sin(t * freq * phi + phase) o3 = torch.sin(t * freq * (phi ** 2) + phase) wave = amp * (o1 + o2 + o3) / 3.0 self.time = self.time + 1.0 return x + 0.08 * wave, threshold_mod # ============================================================================= # 5. GLOBAL / HIGHER-BODY / ELECTROSTATICS # ============================================================================= class GlobalGraphTransformerBlock(nn.Module): def __init__(self, dim, num_heads=8, ff_mult=4): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) self.norm2 = nn.LayerNorm(dim) self.ff = nn.Sequential(nn.Linear(dim, ff_mult * dim), nn.GELU(), nn.Linear(ff_mult * dim, dim)) def forward(self, x): y = self.norm1(x) y, _ = self.attn(y, y, y, need_weights=False) x = x + y x = x + self.ff(self.norm2(x)) return x class TripletPhysicsModule(nn.Module): def __init__(self, dim, hidden=64): super().__init__() self.mlp = nn.Sequential(nn.Linear(3, hidden), nn.SiLU(), nn.Linear(hidden, dim)) def forward(self, x, pos, triplet_idx): B, N, D = x.shape if triplet_idx.shape[1] == 0: return x pos_flat = pos.reshape(B * N, 3) i, j, k = triplet_idx vij = pos_flat[j] - pos_flat[i] vkj = pos_flat[j] - pos_flat[k] dot = (vij * vkj).sum(dim=-1) nij = torch.norm(vij, dim=-1) + 1e-8 nkj = torch.norm(vkj, dim=-1) + 1e-8 theta = torch.acos(torch.clamp(dot / (nij * nkj), -1.0, 1.0)) feat = torch.stack([nij, nkj, theta], dim=-1) msg = self.mlp(feat) out = torch.zeros(B * N, D, device=x.device, dtype=x.dtype) out.index_add_(0, j, msg.to(out.dtype)) return x + 0.03 * out.view(B, N, D) class QuadrupletPhysicsModule(nn.Module): def __init__(self, dim, hidden=64): super().__init__() self.mlp = nn.Sequential(nn.Linear(6, hidden), nn.SiLU(), nn.Linear(hidden, dim)) def forward(self, x, pos, quad_idx): B, N, D = x.shape if quad_idx.shape[1] == 0: return x pos_flat = pos.reshape(B * N, 3) i, a, b, c = quad_idx pi, pa, pb, pc = pos_flat[i], pos_flat[a], pos_flat[b], pos_flat[c] ria = torch.norm(pa - pi, dim=-1) rib = torch.norm(pb - pi, dim=-1) ric = torch.norm(pc - pi, dim=-1) ab = torch.norm(pa - pb, dim=-1) bc = torch.norm(pb - pc, dim=-1) ca = torch.norm(pc - pa, dim=-1) feat = torch.stack([ria, rib, ric, ab, bc, ca], dim=-1) msg = self.mlp(feat) out = torch.zeros(B * N, D, device=x.device, dtype=x.dtype) out.index_add_(0, i, msg.to(out.dtype)) return x + 0.02 * out.view(B, N, D) class ChargeEquilibrationModule(nn.Module): def __init__(self, dim, hidden=64, n_iter=3): super().__init__() self.n_iter = n_iter self.chi = nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, 1)) self.hardness = nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, 1), nn.Softplus()) self.charge_update = nn.Sequential(nn.Linear(3, hidden), nn.SiLU(), nn.Linear(hidden, 1)) def forward(self, x, pos, z): B, N, D = x.shape q = torch.zeros(B, N, 1, device=x.device, dtype=x.dtype) chi = self.chi(x) eta = self.hardness(x) + 1e-4 dist = torch.cdist(pos, pos) + torch.eye(N, device=pos.device).unsqueeze(0) * 1e6 for _ in range(self.n_iter): coul = (q.transpose(1, 2) / dist).sum(dim=-1, keepdim=True) dq = self.charge_update(torch.cat([chi, eta, coul], dim=-1)) q = q + 0.2 * dq q = q - q.mean(dim=1, keepdim=True) return q, chi, eta class LongRangeElectrostatics(nn.Module): def __init__(self, alpha=0.35, gmax=2): super().__init__() self.alpha = alpha self.gmax = gmax def reciprocal_energy(self, q, pos, cell): B, N, _ = pos.shape energy = torch.zeros(B, 1, device=pos.device, dtype=pos.dtype) for b in range(B): lat = cell[b] vol = torch.det(lat).abs() + 1e-8 rec = 2 * np.pi * torch.inverse(lat).t() eb = torch.tensor(0.0, device=pos.device, dtype=pos.dtype) for h in range(-self.gmax, self.gmax + 1): for k in range(-self.gmax, self.gmax + 1): for l in range(-self.gmax, self.gmax + 1): if h == 0 and k == 0 and l == 0: continue G = h * rec[:, 0] + k * rec[:, 1] + l * rec[:, 2] G2 = (G * G).sum() + 1e-8 phase = torch.einsum("nd,d->n", pos[b], G) sf_real = torch.sum(q[b, :, 0] * torch.cos(phase)) sf_imag = torch.sum(q[b, :, 0] * torch.sin(phase)) s2 = sf_real ** 2 + sf_imag ** 2 damp = torch.exp(-G2 / (4 * self.alpha ** 2)) eb = eb + (4 * np.pi / vol) * damp * s2 / G2 energy[b, 0] = 0.5 * eb return energy def real_energy(self, q, pos): B, N, _ = pos.shape dist = torch.cdist(pos, pos) + torch.eye(N, device=pos.device).unsqueeze(0) * 1e6 qq = q @ q.transpose(1, 2) erfc_term = torch.erfc(self.alpha * dist) / dist return 0.5 * (qq * erfc_term).sum(dim=(-1, -2)).unsqueeze(-1) def self_energy(self, q): return -(self.alpha / np.sqrt(np.pi)) * (q.squeeze(-1) ** 2).sum(dim=-1, keepdim=True) def forward(self, q, pos, cell): return self.real_energy(q, pos) + self.reciprocal_energy(q, pos, cell) + self.self_energy(q) # ============================================================================= # 6. TUNNELS # ============================================================================= class EdgeMemoryBank(nn.Module): def __init__(self, node_dim, mem_dim=32): super().__init__() self.mem_dim = mem_dim self.edge_proj = nn.Sequential(nn.Linear(mem_dim + 4, mem_dim), nn.SiLU(), nn.Linear(mem_dim, mem_dim)) self.node_to_mem = nn.Linear(node_dim, mem_dim) self.update = nn.GRUCell(mem_dim, mem_dim) def forward(self, node_src, edge_dist, edge_vec, prev_mem=None): src_mem = self.node_to_mem(node_src) feat = torch.cat([src_mem, edge_dist.unsqueeze(-1), edge_vec], dim=-1) msg = self.edge_proj(feat) if prev_mem is None: prev_mem = torch.zeros_like(msg) return self.update(msg, prev_mem) class TunnelWaveModule(nn.Module): def __init__(self, dim): super().__init__() self.phase = nn.Linear(dim, dim) self.freq = nn.Parameter(torch.randn(1, 1, dim) * 0.05) self.gain = nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid()) def forward(self, x): amp = torch.exp(-torch.norm(x, dim=-1, keepdim=True) * 0.1) return self.gain(x) * amp * torch.sin(self.phase(x) * self.freq) class TunnelMessageBlock(nn.Module): def __init__(self, layer_idx, cutoff, irreps_node=IRREPS_NODE, irreps_sh=IRREPS_EDGE): super().__init__() self.cutoff = cutoff self.bspline = BSplineBasis(num_basis=32, cutoff=cutoff, degree=3) self.tp = FullyConnectedTensorProduct(irreps_node, irreps_sh, irreps_node, shared_weights=False) self.radial = nn.Sequential(nn.Linear(32, 64), nn.SiLU(), nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, self.tp.weight_numel)) self.self_interact = o3.Linear(irreps_node, irreps_node) n_scalar = irreps_node[0].mul self.gate = Gate( o3.Irreps(f"{n_scalar}x0e"), [torch.tanh], o3.Irreps("21x0e"), [torch.sigmoid], o3.Irreps("14x1o + 7x2e") ) self.gate_gen = nn.Linear(n_scalar, self.gate.irreps_gates.dim) self.edge_mem = EdgeMemoryBank(irreps_node.dim, 32) def forward(self, x_flat, edge_index, edge_vec, edge_dist, prev_mem=None): src, dst = edge_index edge_sh = o3.spherical_harmonics(IRREPS_EDGE, edge_vec, normalize=True, normalization="component") weights = self.radial(self.bspline(edge_dist)) node_src = x_flat[src] messages = self.tp(node_src, edge_sh, weights) mem = self.edge_mem(node_src, edge_dist, edge_vec, prev_mem) m = min(messages.shape[-1], mem.shape[-1]) messages[..., :m] = messages[..., :m] + 0.05 * mem[..., :m] out = torch.zeros_like(x_flat) if HAS_PYG: out = scatter(messages, dst, dim=0, dim_size=x_flat.shape[0], reduce="sum") else: out.index_add_(0, dst, messages) out = self.self_interact(x_flat) + out n_scalar = IRREPS_NODE[0].mul scalars = out[..., :n_scalar] gates = self.gate_gen(scalars) out = self.gate(torch.cat([scalars, gates, out[..., n_scalar:]], dim=-1)) return out, mem class CrossTunnelCommunication(nn.Module): def __init__(self, dim): super().__init__() self.t36 = nn.Linear(dim * 2, dim) self.t69 = nn.Linear(dim * 2, dim) self.t39 = nn.Linear(dim * 2, dim) self.norm = nn.LayerNorm(dim) def forward(self, x3, x6, x9): y3 = x3 + 0.05 * self.t39(torch.cat([x3, x9], dim=-1)) y6 = x6 + 0.05 * self.t36(torch.cat([x6, x3], dim=-1)) + 0.05 * self.t69(torch.cat([x6, x9], dim=-1)) y9 = x9 + 0.05 * self.t39(torch.cat([x9, x3], dim=-1)) return self.norm(y3), self.norm(y6), self.norm(y9) class ThreeTunnelPhysics(nn.Module): def __init__(self, num_layers=2): super().__init__() self.layers3 = nn.ModuleList([TunnelMessageBlock(i, cutoff=3.0) for i in range(num_layers)]) self.layers6 = nn.ModuleList([TunnelMessageBlock(i, cutoff=6.0) for i in range(num_layers)]) self.layers9 = nn.ModuleList([TunnelMessageBlock(i, cutoff=9.0) for i in range(num_layers)]) self.cross = nn.ModuleList([CrossTunnelCommunication(IRREPS_NODE.dim) for _ in range(num_layers)]) init = torch.tensor([1 / (phi ** 2), 1 / phi, 1.0], dtype=torch.float32) self.tunnel_logits = nn.Parameter(init.log()) self.wave3 = TunnelWaveModule(IRREPS_NODE.dim) self.wave6 = TunnelWaveModule(IRREPS_NODE.dim) self.wave9 = TunnelWaveModule(IRREPS_NODE.dim) def forward(self, x, pos): B, N, D = x.shape x3 = x + self.wave3(x) x6 = x + self.wave6(x) x9 = x + self.wave9(x) x3f, x6f, x9f = x3.reshape(B * N, D), x6.reshape(B * N, D), x9.reshape(B * N, D) mem3 = mem6 = mem9 = None e3, v3, d3, _, _, _ = build_neighbor_list(pos, cutoff=3.0, max_n=48) e6, v6, d6, _, _, _ = build_neighbor_list(pos, cutoff=6.0, max_n=72) e9, v9, d9, _, _, _ = build_neighbor_list(pos, cutoff=9.0, max_n=96) for i in range(len(self.layers3)): x3f, mem3 = self.layers3[i](x3f, e3, v3, d3, mem3) x6f, mem6 = self.layers6[i](x6f, e6, v6, d6, mem6) x9f, mem9 = self.layers9[i](x9f, e9, v9, d9, mem9) rw3, _, _ = build_latent_rewire_graph(x3f.view(B, N, D), pos, k=min(12, N - 1)) rw6, _, _ = build_latent_rewire_graph(x6f.view(B, N, D), pos, k=min(16, N - 1)) rw9, _, _ = build_latent_rewire_graph(x9f.view(B, N, D), pos, k=min(20, N - 1)) buf3 = torch.zeros_like(x3f); buf6 = torch.zeros_like(x6f); buf9 = torch.zeros_like(x9f) buf3.index_add_(0, rw3[1], x3f[rw3[0]]) buf6.index_add_(0, rw6[1], x6f[rw6[0]]) buf9.index_add_(0, rw9[1], x9f[rw9[0]]) x3f = x3f + 0.02 * buf3 x6f = x6f + 0.02 * buf6 x9f = x9f + 0.02 * buf9 y3, y6, y9 = self.cross[i](x3f.view(B, N, D), x6f.view(B, N, D), x9f.view(B, N, D)) x3f, x6f, x9f = y3.reshape(B * N, D), y6.reshape(B * N, D), y9.reshape(B * N, D) w = F.softmax(self.tunnel_logits, dim=0) return w[0] * x3f.view(B, N, D) + w[1] * x6f.view(B, N, D) + w[2] * x9f.view(B, N, D) # ============================================================================= # 7. BIOLOGICAL BRAIN MODULES (LIF with quantum threshold modulation) # ============================================================================= class LIFNeuronLayer(nn.Module): def __init__(self, dim, tau_m=20.0, v_th=1.0, v_reset=0.0, refractory_steps=2): super().__init__() self.dim = dim self.tau_m = tau_m self.v_th = v_th self.v_reset = v_reset self.refractory_steps = refractory_steps self.input_proj = nn.Linear(dim, dim) self.leak = nn.Parameter(torch.tensor(1.0 / tau_m)) self.register_buffer("membrane", torch.zeros(1, 1, dim)) self.register_buffer("refractory_counter", torch.zeros(1, 1, dim)) def reset_state(self, B, N, device, dtype): self.membrane = torch.zeros(B, N, self.dim, device=device, dtype=dtype) self.refractory_counter = torch.zeros(B, N, self.dim, device=device, dtype=dtype) def forward(self, x, threshold_mod=None): B, N, D = x.shape if self.membrane.shape[:2] != (B, N) or self.membrane.device != x.device: self.reset_state(B, N, x.device, x.dtype) inp = self.input_proj(x) not_refractory = (self.refractory_counter <= 0).float() # v4.1: Quantum Hilbert field modulates firing threshold v_th_eff = self.v_th if threshold_mod is not None: v_th_eff = self.v_th * (1.0 + 0.5 * (threshold_mod - 0.5)) self.membrane = self.membrane + not_refractory * (-self.leak * self.membrane + inp) spikes = (self.membrane >= v_th_eff).float() self.membrane = torch.where(spikes > 0, torch.full_like(self.membrane, self.v_reset), self.membrane) self.refractory_counter = torch.where( spikes > 0, torch.full_like(self.refractory_counter, float(self.refractory_steps)), torch.clamp(self.refractory_counter - 1.0, min=0.0) ) surrogate = torch.sigmoid(5.0 * (inp - v_th_eff)) out = spikes + surrogate - surrogate.detach() return out, self.membrane class IzhikevichNeuronLayer(nn.Module): def __init__(self, dim, a=0.02, b=0.2, c=-65.0, d=8.0): super().__init__() self.dim = dim self.a = a self.b = b self.c = c self.d = d self.in_proj = nn.Linear(dim, dim) self.register_buffer("v", torch.zeros(1, 1, dim)) self.register_buffer("u", torch.zeros(1, 1, dim)) def reset_state(self, B, N, device, dtype): self.v = torch.full((B, N, self.dim), -65.0, device=device, dtype=dtype) self.u = self.b * self.v def forward(self, x): B, N, D = x.shape if self.v.shape[:2] != (B, N) or self.v.device != x.device: self.reset_state(B, N, x.device, x.dtype) I = self.in_proj(x) self.v = self.v + 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + I) self.u = self.u + self.a * (self.b * self.v - self.u) spikes = (self.v >= 30.0).float() self.v = torch.where(spikes > 0, torch.full_like(self.v, self.c), self.v) self.u = torch.where(spikes > 0, self.u + self.d, self.u) surrogate = torch.sigmoid((I - 1.0) * 2.0) out = spikes + surrogate - surrogate.detach() return out, self.v class SynapseDynamics(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim self.ampa_proj = nn.Linear(dim, dim) self.nmda_proj = nn.Linear(dim, dim) self.gaba_proj = nn.Linear(dim, dim) self.ampa_tau = 5.0 self.nmda_tau = 50.0 self.gaba_tau = 10.0 self.register_buffer("ampa_state", torch.zeros(1, 1, dim)) self.register_buffer("nmda_state", torch.zeros(1, 1, dim)) self.register_buffer("gaba_state", torch.zeros(1, 1, dim)) def reset_state(self, B, N, device, dtype): self.ampa_state = torch.zeros(B, N, self.dim, device=device, dtype=dtype) self.nmda_state = torch.zeros(B, N, self.dim, device=device, dtype=dtype) self.gaba_state = torch.zeros(B, N, self.dim, device=device, dtype=dtype) def forward(self, spikes, inhibitory_gate=None): B, N, D = spikes.shape if self.ampa_state.shape[:2] != (B, N) or self.ampa_state.device != spikes.device: self.reset_state(B, N, spikes.device, spikes.dtype) if inhibitory_gate is None: inhibitory_gate = 0.5 self.ampa_state = self.ampa_state * math.exp(-1 / self.ampa_tau) + self.ampa_proj(spikes) self.nmda_state = self.nmda_state * math.exp(-1 / self.nmda_tau) + self.nmda_proj(spikes) self.gaba_state = self.gaba_state * math.exp(-1 / self.gaba_tau) + inhibitory_gate * self.gaba_proj(spikes) syn = self.ampa_state + 0.5 * self.nmda_state - 0.8 * self.gaba_state return syn, self.ampa_state, self.nmda_state, self.gaba_state class NeuromodulatorSystem(nn.Module): def __init__(self, dim): super().__init__() self.to_dopa = nn.Sequential(nn.Linear(dim, dim // 2), nn.SiLU(), nn.Linear(dim // 2, 1)) self.to_ach = nn.Sequential(nn.Linear(dim, dim // 2), nn.SiLU(), nn.Linear(dim // 2, 1)) self.to_5ht = nn.Sequential(nn.Linear(dim, dim // 2), nn.SiLU(), nn.Linear(dim // 2, 1)) self.to_ne = nn.Sequential(nn.Linear(dim, dim // 2), nn.SiLU(), nn.Linear(dim // 2, 1)) def forward(self, x_summary): dopamine = torch.sigmoid(self.to_dopa(x_summary)) acetylcholine = torch.sigmoid(self.to_ach(x_summary)) serotonin = torch.sigmoid(self.to_5ht(x_summary)) noradrenaline = torch.sigmoid(self.to_ne(x_summary)) return { "dopamine": dopamine, "acetylcholine": acetylcholine, "serotonin": serotonin, "noradrenaline": noradrenaline } class STDPPlasticity(nn.Module): def __init__(self, dim, tau_pre=20.0, tau_post=20.0, a_plus=0.01, a_minus=0.012): super().__init__() self.dim = dim self.tau_pre = tau_pre self.tau_post = tau_post self.a_plus = a_plus self.a_minus = a_minus self.register_buffer("pre_trace", torch.zeros(1, 1, dim)) self.register_buffer("post_trace", torch.zeros(1, 1, dim)) def reset_state(self, B, N, device, dtype): self.pre_trace = torch.zeros(B, N, self.dim, device=device, dtype=dtype) self.post_trace = torch.zeros(B, N, self.dim, device=device, dtype=dtype) def forward(self, pre_spikes, post_spikes, reward=None): B, N, D = pre_spikes.shape if self.pre_trace.shape[:2] != (B, N) or self.pre_trace.device != pre_spikes.device: self.reset_state(B, N, pre_spikes.device, pre_spikes.dtype) self.pre_trace = self.pre_trace * math.exp(-1 / self.tau_pre) + pre_spikes self.post_trace = self.post_trace * math.exp(-1 / self.tau_post) + post_spikes ltp = self.a_plus * pre_spikes * self.post_trace ltd = self.a_minus * post_spikes * self.pre_trace delta = ltp - ltd if reward is not None: delta = delta * (1.0 + reward[:B].view(B, 1, 1)) return delta class HomeostaticPlasticity(nn.Module): def __init__(self, target_rate=0.1): super().__init__() self.target_rate = target_rate def forward(self, spikes): rate = spikes.mean(dim=1, keepdim=True) return (self.target_rate - rate) class CorticalMicrocircuit(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim self.layer_embed = nn.Embedding(5, dim) self.pop_proj = nn.Linear(dim, dim) self.lif = LIFNeuronLayer(dim) self.izh = IzhikevichNeuronLayer(dim) self.syn = SynapseDynamics(dim) self.stdp = STDPPlasticity(dim) self.homeo = HomeostaticPlasticity(target_rate=0.12) conn = torch.tensor([ [0.30, 0.40, 0.20, 0.05, 0.05], # L2 -> ... [0.20, 0.25, 0.35, 0.15, 0.05], # L3 [0.10, 0.20, 0.25, 0.35, 0.10], # L4 [0.05, 0.10, 0.20, 0.35, 0.30], # L5 [0.10, 0.10, 0.20, 0.20, 0.40], # L6 ]) self.register_buffer("connectivity", conn) self.out_proj = nn.Sequential(nn.Linear(dim * 4, dim), nn.SiLU(), nn.Linear(dim, dim)) self.last_bio = {} def make_layer_ids(self, N, device): ids = [] for i in range(N): if i < max(1, int(0.18 * N)): ids.append(0) elif i < max(2, int(0.36 * N)): ids.append(1) elif i < max(3, int(0.56 * N)): ids.append(2) elif i < max(4, int(0.78 * N)): ids.append(3) else: ids.append(4) return torch.tensor(ids, device=device, dtype=torch.long) def forward(self, x, neuromod, threshold_mod=None): B, N, D = x.shape layer_ids = self.make_layer_ids(N, x.device) x = x + self.layer_embed(layer_ids).unsqueeze(0) lif_spikes, lif_v = self.lif(x, threshold_mod=threshold_mod) izh_spikes, izh_v = self.izh(x) spikes = 0.6 * lif_spikes + 0.4 * izh_spikes inhib = 0.5 + 0.5 * neuromod["serotonin"].mean() syn_out, ampa, nmda, gaba = self.syn(spikes, inhibitory_gate=inhib) layer_conn = self.connectivity[layer_ids][:, layer_ids] layer_conn = layer_conn.unsqueeze(0) micro = torch.einsum("bij,bjd->bid", layer_conn, self.pop_proj(x)) reward = neuromod["dopamine"].squeeze(-1) stdp_delta = self.stdp(spikes, syn_out, reward=reward) homeo = self.homeo(spikes) mod_scale = ( 1.0 + 0.30 * neuromod["dopamine"] + 0.20 * neuromod["acetylcholine"] + 0.15 * neuromod["noradrenaline"] - 0.10 * neuromod["serotonin"] ).unsqueeze(1) out = self.out_proj(torch.cat([x, syn_out, micro, stdp_delta + homeo], dim=-1)) out = x + mod_scale * out self.last_bio = { "lif_spike_rate": lif_spikes.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), "izh_spike_rate": izh_spikes.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), "ampa_mean": ampa.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), "nmda_mean": nmda.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), "gaba_mean": gaba.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), "stdp_mean": stdp_delta.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), "homeo_mean": homeo.mean(dim=(1, 2), keepdim=False).unsqueeze(-1), } return out # ============================================================================= # 7.5 v4.1 INSERT -- BASAL GANGLIA / HIPPOCAMPUS / WORKING MEMORY # ============================================================================= class BasalGangliaActor(nn.Module): def __init__(self, state_dim, n_actions=8, hidden=128): super().__init__() self.n_actions = n_actions self.direct = nn.Sequential(nn.Linear(state_dim, hidden), nn.SiLU(), nn.Linear(hidden, n_actions)) self.indirect = nn.Sequential(nn.Linear(state_dim, hidden), nn.SiLU(), nn.Linear(hidden, n_actions)) self.snr = nn.Sequential(nn.Linear(state_dim, hidden), nn.SiLU(), nn.Linear(hidden, 1)) self.dopamine_gain = nn.Sequential(nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, n_actions)) def forward(self, state, dopamine=None): Go = self.direct(state) NoGo = self.indirect(state) if dopamine is not None and dopamine.numel() > 0: Go = Go + self.dopamine_gain(dopamine) bg_gate = Go - NoGo action_probs = F.softmax(bg_gate, dim=-1) value = self.snr(state) return action_probs, value class HippocampalEpisodicMemory(nn.Module): def __init__(self, dim, capacity=512): super().__init__() self.capacity = capacity self.dim = dim self.register_buffer("memory", torch.randn(capacity, dim) * 0.01) self.register_buffer("usage", torch.zeros(capacity)) self.write_proj = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim)) self.query_proj = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim)) self.importance_proj = nn.Linear(dim, 1) def write(self, event_vec, importance=None): if event_vec.dim() == 1: event_vec = event_vec.unsqueeze(0) B = event_vec.shape[0] if B > self.capacity: B = self.capacity event_vec = event_vec[:B] _, idx = torch.topk(self.usage, k=B, largest=False) self.memory[idx] = self.write_proj(event_vec).detach() if importance is None: importance = torch.ones(B, device=event_vec.device) elif importance.dim() == 0: importance = importance.unsqueeze(0).expand(B) elif importance.shape[0] != B: importance = importance[:B] if importance.shape[0] > B else importance.expand(B) self.usage[idx] = importance.detach().squeeze() self.usage = self.usage * 0.9999 def retrieve(self, query): if query.dim() == 1: query = query.unsqueeze(0) q = self.query_proj(query) scores = torch.matmul(q, self.memory.t()) / math.sqrt(self.dim) weights = torch.softmax(scores, dim=-1) retrieved = torch.matmul(weights, self.memory) return retrieved, weights def replay(self, n_samples, noise_scale=0.05): device = self.memory.device idx = torch.randint(0, self.capacity, (n_samples,), device=device) samples = self.memory[idx] return samples + torch.randn_like(samples) * noise_scale def novelty_score(self, query): _, weights = self.retrieve(query) familiarity = weights.max(dim=-1).values return 1.0 - familiarity class PrefrontalWorkingMemory(nn.Module): def __init__(self, dim, n_attractors=4): super().__init__() self.dim = dim self.n_attractors = n_attractors self.attractor_pool = nn.Parameter(torch.randn(n_attractors, dim, dim) * 0.01) self.update_gate = nn.GRUCell(dim, dim) self.context_proj = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.LayerNorm(dim)) def init_state(self, batch_size, device): return torch.zeros(batch_size, self.dim, device=device) def forward(self, x, state=None, reset=False): if x.dim() == 3: x = x.mean(dim=1) B = x.shape[0] if state is None or reset or state.shape[0] != B: state = self.init_state(B, x.device) attractor_field = torch.tanh(state @ self.attractor_pool.sum(0).t()) state = self.update_gate(x, state + 0.05 * attractor_field) context = self.context_proj(state) return context, state # ============================================================================= # 8. HIERARCHY # ============================================================================= class HierarchicalGraphBottleneck105(nn.Module): def __init__(self, dim): super().__init__() adj43, region43, coords43, lam43, mod43 = build_srichakra_cortical_graph() adj105, region105, coords105 = build_105node_hierarchical_graph() self.n_sri = 43 self.n_phys = 105 self.node_dim = max(32, dim // 4) self.to_sri = nn.Linear(dim, self.n_sri * self.node_dim) self.to_phys = nn.Linear(dim, self.n_phys * self.node_dim) self.sri_embed = nn.Parameter(torch.randn(self.n_sri, self.node_dim) * 0.02) self.phys_embed = nn.Parameter(torch.randn(self.n_phys, self.node_dim) * 0.02) self.sri_coord_proj = nn.Linear(3, self.node_dim) self.phys_coord_proj = nn.Linear(3, self.node_dim) self.sri_laminar_proj = nn.Linear(6, self.node_dim) self.sri_module_embed = nn.Embedding(4, self.node_dim) self.phys_region_embed = nn.Embedding(5, self.node_dim) self.sri_ff = nn.Sequential(nn.Linear(self.node_dim, self.node_dim), nn.SiLU()) self.sri_fb = nn.Sequential(nn.Linear(self.node_dim, self.node_dim), nn.SiLU()) self.phys_ff = nn.Sequential(nn.Linear(self.node_dim, self.node_dim), nn.SiLU()) self.phys_fb = nn.Sequential(nn.Linear(self.node_dim, self.node_dim), nn.SiLU()) self.cross_gate = nn.Sequential(nn.Linear(self.node_dim * 2, self.node_dim), nn.SiLU(), nn.Linear(self.node_dim, self.node_dim), nn.Sigmoid()) self.from_nodes = nn.Sequential(nn.Linear((self.n_sri + self.n_phys) * self.node_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) self.norm = nn.LayerNorm(dim) self.register_buffer("adj43", adj43) self.register_buffer("coords43", coords43) self.register_buffer("lam43", lam43) self.register_buffer("mod43", mod43) self.register_buffer("adj105", adj105) self.register_buffer("coords105", coords105) self.register_buffer("region105", region105) def forward(self, x): B, N, D = x.shape x_mean = x.mean(dim=1) sri = self.to_sri(x_mean).view(B, self.n_sri, self.node_dim) sri = sri + self.sri_embed.unsqueeze(0) sri = sri + self.sri_coord_proj(self.coords43).unsqueeze(0) sri = sri + self.sri_laminar_proj(self.lam43).unsqueeze(0) sri = sri + self.sri_module_embed(self.mod43).unsqueeze(0) phys = self.to_phys(x_mean).view(B, self.n_phys, self.node_dim) phys = phys + self.phys_embed.unsqueeze(0) phys = phys + self.phys_coord_proj(self.coords105).unsqueeze(0) phys = phys + self.phys_region_embed(self.region105).unsqueeze(0) sri = sri + self.sri_ff(torch.einsum("ij,bjd->bid", self.adj43, sri)) sri = sri + self.sri_fb(torch.einsum("ij,bjd->bid", self.adj43.t(), sri)) phys = phys + self.phys_ff(torch.einsum("ij,bjd->bid", self.adj105, phys)) phys = phys + self.phys_fb(torch.einsum("ij,bjd->bid", self.adj105.t(), phys)) sri_global = sri[:, 0:1, :].expand(-1, self.n_phys, -1) phys_global = phys[:, 0:1, :].expand(-1, self.n_sri, -1) phys = phys + self.cross_gate(torch.cat([phys, sri_global], dim=-1)) * sri_global sri = sri + self.cross_gate(torch.cat([sri, phys_global], dim=-1)) * phys_global fusion = torch.cat([sri.reshape(B, -1), phys.reshape(B, -1)], dim=-1) return self.norm(x + self.from_nodes(fusion).unsqueeze(1)) # ============================================================================= # 9. FIDELITY / SSL / HEADS # ============================================================================= class MultiFidelityHead(nn.Module): def __init__(self, n_scalar): super().__init__() self.base_eam = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) self.delta_dft = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) self.delta_cc = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) def forward(self, scalars): pooled = scalars.mean(dim=1) e_eam = self.base_eam(pooled) d_dft = self.delta_dft(pooled) d_cc = self.delta_cc(pooled) return { "energy_eam": e_eam, "energy_dft": e_eam + d_dft, "energy_cc": e_eam + d_dft + d_cc } class SelfSupervisedHeads(nn.Module): def __init__(self, dim, channel_dim=8): super().__init__() self.mask_reconstruct = nn.Sequential(nn.Linear(dim, 64), nn.SiLU(), nn.Linear(64, channel_dim)) self.projector = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, 128)) def forward(self, x): recon = self.mask_reconstruct(x) proj = F.normalize(self.projector(x.mean(dim=1)), dim=-1) return recon, proj class DiffusionHead(nn.Module): def __init__(self, n_scalar): super().__init__() self.path_head = nn.Sequential(nn.Linear(n_scalar, 32), nn.SiLU(), nn.Linear(32, 1)) self.barrier_head = nn.Sequential(nn.Linear(n_scalar, 32), nn.SiLU(), nn.Linear(32, 1)) self.hop_head = nn.Sequential(nn.Linear(n_scalar, 32), nn.SiLU(), nn.Linear(32, 1)) self.migration_head = nn.Sequential(nn.Linear(n_scalar, 32), nn.SiLU(), nn.Linear(32, 1)) def forward(self, scalars): pooled = scalars.mean(dim=1) return { "path": self.path_head(pooled), "barrier": self.barrier_head(pooled), "hop": self.hop_head(pooled), "migration_barrier": self.migration_head(pooled), } class EquivariantStressHead(nn.Module): def __init__(self, irreps_node=IRREPS_NODE): super().__init__() self.irreps_out = o3.Irreps("1x0e + 1x2e") self.linear = o3.Linear(irreps_node, self.irreps_out) basis = torch.zeros(6, 3, 3) basis[0] = torch.eye(3) / np.sqrt(3.0) basis[1] = torch.diag(torch.tensor([1.0, -1.0, 0.0])) / np.sqrt(2.0) basis[2] = torch.diag(torch.tensor([1.0, 1.0, -2.0])) / np.sqrt(6.0) basis[3] = torch.tensor([[0., 1., 0.], [1., 0., 0.], [0., 0., 0.]]) / np.sqrt(2.0) basis[4] = torch.tensor([[0., 0., 1.], [0., 0., 0.], [1., 0., 0.]]) / np.sqrt(2.0) basis[5] = torch.tensor([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.]]) / np.sqrt(2.0) self.register_buffer("basis", basis) def forward(self, node_features): B, N, D = node_features.shape flat = node_features.view(B * N, D) coeffs = self.linear(flat) coeffs_pooled = coeffs.view(B, N, 6).mean(dim=1) return torch.einsum("bi,ijk->bjk", coeffs_pooled, self.basis) class ExtraPhysicsHeads(nn.Module): def __init__(self, n_scalar): super().__init__() self.elastic = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 3)) self.vacancy = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) self.phonon_dos = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 16)) self.charge = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) def forward(self, scalars): pooled = scalars.mean(dim=1) elastic = self.elastic(pooled) return { "C11": elastic[:, 0:1], "C12": elastic[:, 1:2], "C44": elastic[:, 2:3], "vacancy_formation": self.vacancy(pooled), "phonon_dos": self.phonon_dos(pooled), "charge_equil": self.charge(pooled), } # ============================================================================= # 10. MAIN BACKBONE (v4.1 quantum-brain) # ============================================================================= class EquivariantPhysicsBackboneV41(nn.Module): def __init__(self, num_layers=NUM_LAYERS, seed=None): super().__init__() if seed is not None: torch.manual_seed(seed) self.encoder = MultiPhysicsEncoderV4(IRREPS_NODE) # v4.1: Quantum oscillator replaces fixed brain waves self.osc = QuantumNeuralOscillator(IRREPS_NODE.dim) # v4.1: Maya Ida/Pingala router (7-head E/I attention) self.maya_router = MayaIdaPingalaRouter(IRREPS_NODE.dim, heads=7) self.tunnels = ThreeTunnelPhysics(num_layers=max(4, num_layers // 3)) self.global_transformer = GlobalGraphTransformerBlock(IRREPS_NODE.dim, num_heads=8) self.triplet = TripletPhysicsModule(IRREPS_NODE.dim) self.quadruplet = QuadrupletPhysicsModule(IRREPS_NODE.dim) self.hierarchy = HierarchicalGraphBottleneck105(IRREPS_NODE.dim) self.charge_eq = ChargeEquilibrationModule(IRREPS_NODE.dim, hidden=64, n_iter=3) self.longrange = LongRangeElectrostatics(alpha=0.35, gmax=2) self.neuromod = NeuromodulatorSystem(IRREPS_NODE.dim) self.microcircuit = CorticalMicrocircuit(IRREPS_NODE.dim) # v4.1 cognitive systems self.basal_ganglia = BasalGangliaActor(IRREPS_NODE.dim, n_actions=8) self.hippocampus = HippocampalEpisodicMemory(IRREPS_NODE.dim, capacity=512) self.working_memory = PrefrontalWorkingMemory(IRREPS_NODE.dim, n_attractors=4) n_scalar = IRREPS_NODE[0].mul self.head_energy = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) self.head_sigma = nn.Sequential(nn.Linear(n_scalar, 64), nn.SiLU(), nn.Linear(64, 1)) self.head_stress = EquivariantStressHead(IRREPS_NODE) self.head_diffusion = DiffusionHead(n_scalar) self.head_extra = ExtraPhysicsHeads(n_scalar) self.head_fidelity = MultiFidelityHead(n_scalar) self.ssl_heads = SelfSupervisedHeads(IRREPS_NODE.dim, channel_dim=8) def reset_working_memory(self): self._wm_state = None def forward(self, ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=None, z=None, mask_ratio=0.0, wm_state=None, reset_wm=False): if lattice is None: B = pos.shape[0] lattice = torch.eye(3, device=pos.device).unsqueeze(0).repeat(B, 1, 1) * 20.0 if z is None: z = torch.ones(pos.shape[:2], dtype=torch.long, device=pos.device) # v4.1: encoder now exports quantum neural parameters s, coherence, phase_order, neural_freq, neural_sync, threshold_mod = self.encoder( ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, mask_ratio=mask_ratio ) # Quantum Hilbert waves -> neural oscillatory synchrony s, threshold_mod = self.osc(s, neural_freq, neural_sync, threshold_mod) # Maya 7-head E/I attention (Ida/Pingala balance) s = s + 0.05 * self.maya_router(s) s = self.tunnels(s, pos) s = self.global_transformer(s) _, _, _, _, _, triplet_idx = build_neighbor_list(pos, cutoff=6.0, max_n=64) quad_idx = build_quadruplets(pos, cutoff=6.0) s = self.triplet(s, pos, triplet_idx) s = self.quadruplet(s, pos, quad_idx) s = self.hierarchy(s) neuromod = self.neuromod(s.mean(dim=1)) # v4.1: pass quantum threshold modulation into cortical microcircuit s = self.microcircuit(s, neuromod, threshold_mod=threshold_mod) # Hippocampal context retrieval hippo_context, _ = self.hippocampus.retrieve(s.mean(dim=1)) s = s + 0.05 * hippo_context.unsqueeze(1) # Working memory persistent attractor wm_context, wm_state_out = self.working_memory(s, state=wm_state, reset=reset_wm) s = s + 0.05 * wm_context.unsqueeze(1) q, chi, eta = self.charge_eq(s, pos, z) e_long = self.longrange(q, pos, lattice) n_scalar = IRREPS_NODE[0].mul scalars = s[..., :n_scalar] pooled = scalars.mean(dim=1) energy_local = self.head_energy(pooled) log_sigma = self.head_sigma(pooled).clamp(min=-6.0, max=3.0) stress = self.head_stress(s) diff_heads = self.head_diffusion(scalars) extra_heads = self.head_extra(scalars) mf = self.head_fidelity(scalars) recon, proj = self.ssl_heads(s) state_summary = pooled out = { "energy": energy_local + e_long, "energy_local": energy_local, "energy_long": e_long, "log_sigma": log_sigma, "stress": stress, "coherence": coherence.mean(dim=1), "phase_order": phase_order.mean(dim=1), "charges": q, "electronegativity": chi, "hardness": eta, "ssl_recon": recon, "ssl_proj": proj, "dopamine": neuromod["dopamine"], "acetylcholine": neuromod["acetylcholine"], "serotonin": neuromod["serotonin"], "noradrenaline": neuromod["noradrenaline"], "state_summary": state_summary, "wm_state": wm_state_out, } out.update(self.microcircuit.last_bio) out.update(diff_heads) out.update(extra_heads) out.update(mf) return out class UncertaintyEnsemble(nn.Module): def __init__(self, base_class, num_models=NUM_MODELS, **model_kwargs): super().__init__() self.models = nn.ModuleList([base_class(seed=i * 42, **model_kwargs) for i in range(num_models)]) def forward(self, ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=None, z=None, wm_state=None, reset_wm=False): outs = [m(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z, wm_state=wm_state, reset_wm=reset_wm) for m in self.models] e = torch.stack([o["energy"] for o in outs], dim=0) s = torch.stack([o["stress"] for o in outs], dim=0) vals = torch.stack([o.get("state_summary", torch.zeros_like(e[0])) for o in outs], dim=0) return { "energy_mean": e.mean(dim=0), "energy_std": e.std(dim=0), "stress_mean": s.mean(dim=0), "stress_std": s.std(dim=0), "value_mean": vals.mean(dim=0), } # ============================================================================= # 11. DATA # ============================================================================= class HEMBDataset(Dataset): def __init__(self, samples): self.samples = samples e = torch.stack([s["energy"] for s in samples]).float() self.mean = e.mean() self.std = e.std() if self.std < 1e-8: self.std = torch.tensor(1.0) if "stress" in samples[0] and isinstance(samples[0]["stress"], torch.Tensor): s = torch.stack([s["stress"] for s in samples]).float() self.stress_mean = s.mean() self.stress_std = s.std() if self.stress_std < 1e-8: self.stress_std = torch.tensor(1.0) else: self.stress_mean = torch.tensor(0.0) self.stress_std = torch.tensor(1.0) def __len__(self): return len(self.samples) def __getitem__(self, i): s = self.samples[i] out = {**s, "energies_norm": (s["energy"] - self.mean) / self.std} if "stress" in s and isinstance(s["stress"], torch.Tensor): out["stress_norm"] = (s["stress"] - self.stress_mean) / self.stress_std out["fidelity"] = s.get("fidelity", "dft") return out def collate_fn(batch): out = {} for k in batch[0]: if isinstance(batch[0][k], torch.Tensor): out[k] = torch.stack([b[k].unsqueeze(0) if b[k].dim() == 1 else b[k] for b in batch]) if all(b[k].shape == batch[0][k].shape for b in batch) else torch.nn.utils.rnn.pad_sequence([b[k] for b in batch], batch_first=True) else: out[k] = [b[k] for b in batch] return out def compute_eam_fe_h(pos, z): n = len(z) dist = torch.norm(pos.unsqueeze(1) - pos.unsqueeze(0), dim=-1) dist = dist + torch.eye(n, device=dist.device) * 1e6 mask_fe = (z == 26).float() D, alpha, r0 = 0.8, 1.5, 2.5 morse = D * (torch.exp(-2 * alpha * (dist - r0)) - 2 * torch.exp(-alpha * (dist - r0))) morse = morse * torch.sigmoid((5.0 - dist) * 2.0) rho = (mask_fe.unsqueeze(0) * torch.exp(-((dist - r0) / 1.2) ** 2)).sum(dim=1) embed = -8.0 * torch.sqrt(rho + 1e-6) h_traps = 0.0 h_idx = (z == 1).nonzero(as_tuple=True)[0] for hi in h_idx: h_traps += -2.2 * torch.exp(-(pos[hi, 2] ** 2) / 0.35) h_h = 0.0 if len(h_idx) > 1: for a in range(len(h_idx)): for b in range(a + 1, len(h_idx)): h_h += 2.0 * torch.exp(-torch.norm(pos[h_idx[a]] - pos[h_idx[b]]) / 0.4) fe_h = 0.0 for hi in h_idx: fe_h += (mask_fe * 1.5 * torch.exp(-torch.norm(pos - pos[hi], dim=-1) / 0.35)).sum() return morse.sum() + embed.sum() + h_traps + h_h + fe_h def compute_virial_stress(pos, forces, lattice, velocities=None): vol = torch.det(lattice).clamp(min=1e-8) stress = torch.einsum("bij,bik->bjk", pos, forces) if velocities is not None: ke = torch.einsum("bij,bik->bjk", velocities, velocities) stress = stress + ke stress = 0.5 * (stress + stress.transpose(1, 2)) stress = stress / vol.view(-1, 1, 1) return stress def generate_diverse_gb_dataset(n_configs=200): samples = [] structures = ["bcc", "fcc", "hcp_packing"] fidelities = ["eam", "dft", "cc"] for idx in range(n_configs): struct = random.choice(structures) nx, ny = random.randint(2, 3), random.randint(2, 3) nz = random.randint(2, 3) a = 2.87 if struct == "bcc" else 3.57 if struct == "hcp_packing": a = 2.70 basis = torch.tensor([[0., 0., 0.], [0.5, 3**0.5/6, 0.5], [0.5, 3**0.5/3, 0.0]], dtype=torch.float32) * a else: basis = torch.tensor([[0., 0., 0.], [0.5, 0.5, 0.5]], dtype=torch.float32) * a grid = torch.stack(torch.meshgrid(torch.arange(nx), torch.arange(ny), torch.arange(nz), indexing="ij"), dim=-1).float() * a fe_pos = (grid.view(-1, 1, 3) + basis.view(1, -1, 3)).view(-1, 3) n_fe = min(len(fe_pos), 20) fe_pos = fe_pos[:n_fe // 2] sep = 2.0 + random.random() * 4.0 grain_a = fe_pos.clone(); grain_a[:, 2] += sep / 2 grain_b = fe_pos.clone(); grain_b[:, 2] -= sep / 2 angle = random.random() * 30.0 theta = torch.tensor(angle * np.pi / 180, dtype=torch.float32) R = torch.tensor([[torch.cos(theta), -torch.sin(theta), 0], [torch.sin(theta), torch.cos(theta), 0], [0, 0, 1.0]], dtype=torch.float32) grain_b = (R @ grain_b.T).T all_fe = torch.cat([grain_a, grain_b], dim=0) n_fe_total = min(len(all_fe), 200) n_h = random.randint(1, 3) h_positions = [] for _ in range(n_h): if random.random() < 0.7: h_pos = torch.tensor([(random.random() - 0.5) * a * nx, (random.random() - 0.5) * a * ny, (random.random() - 0.5) * 1.0], dtype=torch.float32) else: h_pos = torch.tensor([(random.random() - 0.5) * a * nx, (random.random() - 0.5) * a * ny, (random.random() - 0.5) * sep * 0.8], dtype=torch.float32) h_positions.append(h_pos) h_pos = torch.stack(h_positions) if h_positions else torch.empty(0, 3, dtype=torch.float32) pos = torch.cat([all_fe[:n_fe_total], h_pos], dim=0) pos = pos - pos.mean(dim=0) z = torch.cat([torch.full((n_fe_total,), 26, dtype=torch.long), torch.full((n_h,), 1, dtype=torch.long)]) pos = pos + torch.randn_like(pos) * (0.02 + random.random() * 0.08) pos = pos.requires_grad_(True) energy = compute_eam_fe_h(pos, z) fidelity = random.choice(fidelities) if fidelity == "dft": energy = energy + 0.1 * torch.sin(pos.norm(dim=-1)).sum() elif fidelity == "cc": energy = energy + 0.1 * torch.sin(pos.norm(dim=-1)).sum() + 0.02 * (pos[:, 0] ** 2).sum() forces = -torch.autograd.grad(energy, pos)[0] maxr = (pos.max(dim=0)[0] - pos.min(dim=0)[0]).max().item() lattice = torch.eye(3) * (maxr + 6.0) stress = compute_virial_stress(pos.unsqueeze(0), forces.unsqueeze(0), lattice.unsqueeze(0)).squeeze(0) samples.append({ "positions": pos.detach(), "atomic_numbers": z, "lattice": lattice, "energy": energy.detach(), "forces": forces.detach(), "stress": stress.detach(), "fidelity": fidelity }) return samples def load_materials_project_data(api_key="Xg7UiwgAHBFpvJjSUISGQDJASdA5YP5k", n_structures=500, save_dir=DFT_DIR): if not HAS_PMG or api_key is None: api_key = os.environ.get("MP_API_KEY", "Xg7UiwgAHBFpvJjSUISGQDJASdA5YP5k") if api_key is None: return generate_diverse_gb_dataset(n_configs=200) os.makedirs(save_dir, exist_ok=True) return generate_diverse_gb_dataset(n_configs=min(n_structures, 200)) def load_mptrj_real(cache_dir=DFT_DIR, max_samples=50000): try: from datasets import load_dataset ds = load_dataset("CederGroupHub/mptrj", split="train", streaming=True) samples = [] for i, item in enumerate(ds): if i >= max_samples: break try: samples.append({ "positions": torch.tensor(item["pos"]).float(), "atomic_numbers": torch.tensor(item["numbers"]).long(), "energy": torch.tensor(item["energy"]).float(), "forces": torch.tensor(item["forces"]).float(), "lattice": torch.tensor(item["cell"]).float() if "cell" in item else torch.eye(3) * 20.0, "stress": torch.tensor(item["stress"]).float() if "stress" in item else torch.zeros(3, 3), "fidelity": "dft" }) except Exception: continue return samples except Exception: return load_materials_project_data(n_structures=max_samples // 10) def load_oc20_real(split="val_id", max_samples=20000, cache_dir=DFT_DIR): try: from fairchem.core.datasets import LmdbDataset lmdb_path = os.path.join(cache_dir, "oc20", f"{split}.lmdb") if not os.path.exists(lmdb_path): return [] dataset = LmdbDataset({"src": lmdb_path}) samples = [] for i in range(min(len(dataset), max_samples)): data = dataset[i] samples.append({ "positions": data.pos, "atomic_numbers": data.atomic_numbers, "energy": data.y, "forces": data.force, "lattice": data.cell if hasattr(data, "cell") else torch.eye(3) * 20, "stress": torch.zeros(3, 3), "fidelity": "dft" }) return samples except Exception: return [] # ============================================================================= # 12. ACTIVE LEARNING PIPELINE # ============================================================================= class ActiveLearningManager: def __init__(self, ensemble_model, acquisition_size=16, work_dir=AL_DIR): self.ensemble = ensemble_model self.acquisition_size = acquisition_size self.work_dir = work_dir os.makedirs(self.work_dir, exist_ok=True) def score_pool(self, pool_samples): scored = [] for idx, s in enumerate(pool_samples): pos = s["positions"].unsqueeze(0).to(device) z = s["atomic_numbers"].unsqueeze(0).to(device) lattice = s.get("lattice", torch.eye(3).unsqueeze(0) * 20.0).to(device) ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos, z) with torch.no_grad(): out = self.ensemble(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z) uncertainty = out["energy_std"].abs().mean().item() + 0.2 * out["stress_std"].abs().mean().item() novelty = self.ensemble.models[0].hippocampus.novelty_score(out.get("value_mean", out["energy_mean"])).item() score = uncertainty + 0.5 * novelty scored.append((idx, score, uncertainty, novelty)) scored = sorted(scored, key=lambda x: x[1], reverse=True) chosen = [pool_samples[i] for i, _, _, _ in scored[:self.acquisition_size]] with open(os.path.join(self.work_dir, "acquisition_scores.json"), "w") as f: json.dump([{"idx": i, "score": s, "uncertainty": u, "novelty": n} for i, s, u, n in scored], f, indent=2) return chosen, scored def launch_mock_dft_jobs(self, samples): job_records = [] for i, s in enumerate(samples): job_id = f"DFT_JOB_{i:05d}" record = { "job_id": job_id, "status": "submitted", "n_atoms": int(s["positions"].shape[0]), "fidelity_target": "dft" } job_records.append(record) with open(os.path.join(self.work_dir, "submitted_jobs.json"), "w") as f: json.dump(job_records, f, indent=2) return job_records def collect_mock_dft_results(self, samples): enriched = [] for s in samples: s2 = copy.deepcopy(s) pos = s2["positions"].clone().requires_grad_(True) z = s2["atomic_numbers"] e = compute_eam_fe_h(pos, z) + 0.15 * torch.cos(pos.norm(dim=-1)).sum() + 0.03 * torch.sin(pos[:, 0]).sum() f = -torch.autograd.grad(e, pos)[0] lattice = s2.get("lattice", torch.eye(3) * 20.0) stress = compute_virial_stress(pos.unsqueeze(0), f.unsqueeze(0), lattice.unsqueeze(0)).squeeze(0) s2["energy"] = e.detach() s2["forces"] = f.detach() s2["stress"] = stress.detach() s2["fidelity"] = "dft" enriched.append(s2) return enriched def append_and_rebuild_dataset(self, base_samples, new_samples): merged = base_samples + new_samples with open(os.path.join(self.work_dir, "dataset_sizes.json"), "w") as f: json.dump({"before": len(base_samples), "added": len(new_samples), "after": len(merged)}, f, indent=2) return merged def sleep_consolidation(self, model, n_replays=200): if not hasattr(model, 'hippocampus'): return replays = model.hippocampus.replay(n_replays, noise_scale=0.06) with torch.no_grad(): model.hippocampus.write(replays, importance=torch.ones(n_replays, device=replays.device) * 0.5) print(f"[Sleep] Consolidated {n_replays} replay vectors.") def run_cycle(self, base_samples, pool_samples): selected, scored = self.score_pool(pool_samples) self.launch_mock_dft_jobs(selected) new_labels = self.collect_mock_dft_results(selected) merged = self.append_and_rebuild_dataset(base_samples, new_labels) if hasattr(self.ensemble.models[0], 'hippocampus'): for s in new_labels: pos = s["positions"].unsqueeze(0).to(device) z = s["atomic_numbers"].unsqueeze(0).to(device) ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos, z) with torch.no_grad(): out = self.ensemble.models[0](ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, z=z) summary = out["state_summary"].squeeze(0) self.ensemble.models[0].hippocampus.write(summary, importance=out.get("energy_std", torch.tensor(1.0)).abs()) return merged, selected, new_labels, scored # ============================================================================= # 13. MD / UTILS / PLOTS # ============================================================================= class MDState: def __init__(self, pos, vel, mass, cell=None): self.pos = pos self.vel = vel self.mass = mass self.cell = cell class ModelWithStats: def __init__(self, model, mean, std): self.model = model self.mean = mean self.std = std def __call__(self, *args, **kwargs): out = self.model(*args, **kwargs) out["energy"] = out["energy"] * self.std + self.mean return out def get_forces(model, state, z, wm_state=None, reset_wm=False): pos = state.pos.unsqueeze(0) if state.pos.dim() == 2 else state.pos pos = pos.requires_grad_(True) z = z.unsqueeze(0) if z.dim() == 1 else z lattice = state.cell.unsqueeze(0) if state.cell is not None and state.cell.dim() == 2 else state.cell if lattice is None: lattice = torch.eye(3, device=pos.device).unsqueeze(0) * 20.0 ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos.squeeze(0), z.squeeze(0)) out = model(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z, wm_state=wm_state, reset_wm=reset_wm) forces = -torch.autograd.grad(out["energy"].sum(), pos, create_graph=False)[0] return out["energy"].detach().squeeze(), forces.squeeze(0), pos.detach().squeeze(0), out.get("wm_state") def plot_srichakra_cortical_map(filename="srichakra_cortical_map.png"): adj, region, coords, laminar, modules = build_srichakra_cortical_graph() coords_np = coords.cpu().numpy() region_np = region.cpu().numpy() fig = plt.figure(figsize=(8, 8)) ax = fig.add_subplot(111) for i in range(adj.shape[0]): for j in range(i + 1, adj.shape[1]): if adj[i, j] > 0: ax.plot([coords_np[i, 0], coords_np[j, 0]], [coords_np[i, 1], coords_np[j, 1]], color="lightgray", lw=0.8, alpha=0.6) colors = {6: "gold", 5: "crimson", 4: "royalblue", 3: "seagreen", 2: "purple"} for r in [6, 5, 4, 3, 2]: idx = np.where(region_np == r)[0] ax.scatter(coords_np[idx, 0], coords_np[idx, 1], s=90, c=colors[r], label=f"Avarana {r}") ax.scatter(coords_np[0, 0], coords_np[0, 1], s=220, c="black", marker="*", label="Bindu / Thalamus") ax.set_title("Sri Chakra 43 Faces -> Cortical Columns (Exact Intersection Geometry)") ax.set_aspect("equal") ax.legend() ax.axis("off") plt.tight_layout() plt.savefig(filename, dpi=200, bbox_inches="tight") plt.close() def make_parity_plot(model, dataset, device): model.eval() e_true, e_pred = [], [] for i in range(min(40, len(dataset))): s = dataset[i] pos = s["positions"].unsqueeze(0).to(device).requires_grad_(True) z = s["atomic_numbers"].unsqueeze(0).to(device) lattice = s["lattice"].unsqueeze(0).to(device) ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos.squeeze(0), z.squeeze(0)) with torch.enable_grad(): out = model(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z) e_true.append(s["energy"].item()) e_pred.append(out["energy"].item() * dataset.std.item() + dataset.mean.item()) fig, ax = plt.subplots(figsize=(5, 5)) ax.scatter(e_true, e_pred, alpha=0.6) lim = [min(e_true + e_pred), max(e_true + e_pred)] ax.plot(lim, lim, "r--") ax.set_title(f"E MAE={np.mean(np.abs(np.array(e_true) - np.array(e_pred))):.3f}") ax.grid(True) plt.tight_layout() plt.savefig(os.path.join(PLOT_DIR, "parity_plot.png"), dpi=200) plt.close() # ============================================================================= # 14. GCMC / KMC / NEB # ============================================================================= class GrandCanonicalMC: def __init__(self, model, fe_pos, fe_z, T=300.0, mu_H=-2.5, device=device): self.model = model.eval() self.fe_pos = fe_pos self.fe_z = fe_z self.T = T self.mu_H = mu_H self.device = device self.beta = 1.0 / (kB_eV * T) self.h_pos = torch.empty(0, 3, device=device) if hasattr(self.model, 'reset_working_memory'): self.model.reset_working_memory() self.wm_state = None def _get_summary(self, h_pos): if len(h_pos) == 0: h_pos = torch.empty(0, 3, device=self.device) pos = torch.cat([self.fe_pos, h_pos], dim=0).unsqueeze(0) z = torch.cat([self.fe_z, torch.ones(len(h_pos), dtype=torch.long, device=self.device)]) lattice = torch.eye(3, device=self.device).unsqueeze(0) * 20.0 ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos, z.unsqueeze(0)) with torch.no_grad(): out = self.model(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z.unsqueeze(0), wm_state=self.wm_state, reset_wm=False) self.wm_state = out.get("wm_state") return out["state_summary"], out["energy"].item() def get_energy(self, h_pos): _, e = self._get_summary(h_pos) return e def _bg_proposal(self, state_summary): if not hasattr(self.model, 'basal_ganglia'): return None dopamine = torch.tensor([[0.5]], device=self.device) with torch.no_grad(): probs, value = self.model.basal_ganglia(state_summary, dopamine) action = torch.multinomial(probs, 1).item() return action, value.item() def run(self, n_steps=1000): for step in range(n_steps): summary, _ = self._get_summary(self.h_pos) bg_action = self._bg_proposal(summary) if step > 100 else None if np.random.rand() < 0.5: if bg_action is not None and bg_action[0] % 2 == 0: new_h = torch.randn(1, 3, device=self.device) * 0.3 else: new_h = torch.randn(1, 3, device=self.device) * 0.5 dE = self.get_energy(torch.cat([self.h_pos, new_h], dim=0)) - self.get_energy(self.h_pos) if np.random.rand() < min(1, np.exp(-self.beta * (dE - self.mu_H))): self.h_pos = torch.cat([self.h_pos, new_h], dim=0) else: if len(self.h_pos) > 0: idx = np.random.randint(len(self.h_pos)) if bg_action is not None and bg_action[0] % 2 == 1: idx = len(self.h_pos) - 1 h_new = torch.cat([self.h_pos[:idx], self.h_pos[idx + 1:]], dim=0) dE = self.get_energy(h_new) - self.get_energy(self.h_pos) if np.random.rand() < min(1, np.exp(-self.beta * (dE + self.mu_H))): self.h_pos = h_new return len(self.h_pos) class KineticMonteCarlo: def __init__(self, barriers, T=300, nu0=1e13): self.barriers = barriers self.kT = kB_eV * T self.nu0 = nu0 self.time = 0.0 def step(self, site): rates = {j: self.nu0 * np.exp(-e / self.kT) for (i, j), e in self.barriers.items() if i == site} if not rates: return site, 0.0 Rtot = sum(rates.values()) r = np.random.rand() * Rtot cum = 0.0 next_site = site for j, rate in rates.items(): cum += rate if r <= cum: next_site = j break dt = -np.log(np.random.rand()) / Rtot self.time += dt return next_site, dt class NEBPathFinder: def __init__(self, model, device=device): self.model = model self.device = device def find_barrier(self, pos_A, pos_B, z, n_images=7): images = [pos_A * (1 - i / (n_images - 1)) + pos_B * i / (n_images - 1) for i in range(n_images)] energies = [] for img in images: pos = img.unsqueeze(0).to(self.device) z_b = z.unsqueeze(0).to(self.device) lattice = torch.eye(3, device=self.device).unsqueeze(0) * 20.0 ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos, z_b) with torch.no_grad(): out = self.model(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z_b) energies.append(out["energy"].item()) return max(energies) - energies[0], energies def pressure_to_mu_H(P_bar, T=300.0, P0=1.0): mu_H2_0 = -4.48 mu_H2 = mu_H2_0 + kB_eV * T * np.log(P_bar / P0) return mu_H2 / 2.0 # ============================================================================= # 15. LOSSES # ============================================================================= def info_nce_loss(z1, z2, temperature=0.1): logits = z1 @ z2.t() / temperature labels = torch.arange(z1.shape[0], device=z1.device) return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) def compute_multifidelity_loss(out, batch): fidelity = batch["fidelity"] target = batch["energies_norm"].to(device).view(-1, 1) loss = torch.tensor(0.0, device=device) for i, fid in enumerate(fidelity): if fid == "eam": loss = loss + F.mse_loss(out["energy_eam"][i:i+1], target[i:i+1]) elif fid == "dft": loss = loss + F.mse_loss(out["energy_dft"][i:i+1], target[i:i+1]) else: loss = loss + F.mse_loss(out["energy_cc"][i:i+1], target[i:i+1]) return loss / max(1, len(fidelity)) def compute_biological_loss(out): target_mid = torch.full_like(out["dopamine"], 0.5) loss = ( F.mse_loss(out["lif_spike_rate"], torch.full_like(out["lif_spike_rate"], 0.10)) + F.mse_loss(out["izh_spike_rate"], torch.full_like(out["izh_spike_rate"], 0.10)) + F.mse_loss(out["ampa_mean"], torch.zeros_like(out["ampa_mean"])) + F.mse_loss(out["nmda_mean"], torch.zeros_like(out["nmda_mean"])) + F.mse_loss(out["gaba_mean"], torch.zeros_like(out["gaba_mean"])) + F.mse_loss(out["stdp_mean"], torch.zeros_like(out["stdp_mean"])) + F.mse_loss(out["homeo_mean"], torch.zeros_like(out["homeo_mean"])) + F.mse_loss(out["dopamine"], target_mid) + F.mse_loss(out["acetylcholine"], target_mid) + F.mse_loss(out["serotonin"], target_mid) + F.mse_loss(out["noradrenaline"], target_mid) ) return loss # ============================================================================= # 16. TRAIN / ACTIVE LEARNING MAIN # ============================================================================= def train_model(model, train_ds, test_ds, epochs=40, batch_size=1): loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn) if USE_COMPILE: model = torch.compile(model, mode="reduce-overhead") try: optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01, fused=True) except (TypeError, RuntimeError): optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(epochs, 20), eta_min=1e-6) scaler = torch.cuda.amp.GradScaler(enabled=USE_FP16) if USE_FP16 else None train_mean = train_ds.mean.to(device) train_std = train_ds.std.to(device) train_stress_mean = train_ds.stress_mean.to(device) if torch.is_tensor(train_ds.stress_mean) else torch.tensor(train_ds.stress_mean, device=device) train_stress_std = train_ds.stress_std.to(device) if torch.is_tensor(train_ds.stress_std) else torch.tensor(train_ds.stress_std, device=device) best_val = float("inf") amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16 for epoch in range(epochs): model.train() total_loss = 0.0 n_batch = 0; mse_sum = 0.0; mae_sum = 0.0 do_consistency = (CONSIST_EVERY > 0) and (epoch % CONSIST_EVERY == 0) for batch in loader: optimizer.zero_grad(set_to_none=True) pos = batch["positions"].to(device).requires_grad_(True) z = batch["atomic_numbers"].to(device) lattice = batch.get("lattice", torch.eye(3, device=device).unsqueeze(0).repeat(pos.shape[0], 1, 1) * 20.0) if isinstance(lattice, list): lattice = torch.stack(lattice).to(device) else: lattice = lattice.to(device) ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos, z) with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=(USE_BF16 or USE_FP16)): out = model(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z, mask_ratio=0.15) e_pred = out["energy"] s_pred = out["stress"] log_sigma = out["log_sigma"] target_e = batch["energies_norm"].to(device).view(-1) loss_e = torch.mean(torch.exp(-log_sigma.view(-1)) * (e_pred.view(-1) - target_e) ** 2 + log_sigma.view(-1)) forces_pred = -torch.autograd.grad(e_pred.sum(), pos, create_graph=True, retain_graph=True)[0] target_f = batch["forces"].to(device) loss_f = F.smooth_l1_loss(forces_pred, target_f, beta=0.1) if "stress_norm" in batch: stress_true = batch["stress_norm"].to(device) else: with torch.no_grad(): raw_stress = compute_virial_stress(pos.detach(), target_f, lattice) stress_true = (raw_stress - train_stress_mean) / (train_stress_std + 1e-8) loss_s = F.mse_loss(s_pred, stress_true) diff_target_path = diffusion.mean(dim=1)[..., 2:3] diff_target_barrier = diffusion.mean(dim=1)[..., 3:4] diff_target_hop = diffusion.mean(dim=1)[..., 4:5] loss_diff = ( F.mse_loss(out["path"], diff_target_path) + F.mse_loss(out["barrier"], diff_target_barrier) + F.mse_loss(out["hop"], diff_target_hop) + F.mse_loss(out["migration_barrier"], diff_target_barrier) ) loss_electro = F.mse_loss(out["charge_equil"], torch.zeros_like(out["charge_equil"])) + 0.01 * out["energy_long"].abs().mean() aux_zero = torch.zeros_like(out["coherence"]) loss_aux = ( F.mse_loss(out["coherence"], aux_zero + 0.5) + F.mse_loss(out["phase_order"], aux_zero) + F.mse_loss(out["vacancy_formation"], defect.mean(dim=1)[..., 0:1]) + F.mse_loss(out["C11"], ss.mean(dim=1)[..., 0:1]) + F.mse_loss(out["C12"], ss.mean(dim=1)[..., 1:2]) + F.mse_loss(out["C44"], ss.mean(dim=1)[..., 2:3]) ) loss_mf = compute_multifidelity_loss(out, batch) aug_noise = 0.01 * torch.randn_like(pos) pos2 = (pos.detach() + aug_noise).requires_grad_(False) ld2, gf2, ie2, gc2, ss2, hs2, defect2, diffusion2, phonon2 = build_multi_physics_channels(pos2, z) with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=(USE_BF16 or USE_FP16)): out2 = model(ld2, gf2, ie2, gc2, ss2, hs2, defect2, diffusion2, phonon2, pos2, lattice=lattice, z=z, mask_ratio=0.15) loss_ssl = F.mse_loss(out["ssl_recon"], ld.detach()) + info_nce_loss(out["ssl_proj"], out2["ssl_proj"]) loss_bio = compute_biological_loss(out) loss_cons = torch.tensor(0.0, device=device) if do_consistency: with torch.no_grad(): dx = torch.randn_like(pos) * 0.01 pos1 = (pos + dx).detach().requires_grad_(False) ld1, gf1, ie1, gc1, ss1, hs1, d1, df1, ph1 = build_multi_physics_channels(pos1, z) with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=(USE_BF16 or USE_FP16)): out1 = model(ld1, gf1, ie1, gc1, ss1, hs1, d1, df1, ph1, pos1, lattice=lattice, z=z) work = -(dx * forces_pred.detach()).sum(dim=(1, 2)) loss_cons = F.mse_loss(out1["energy"] - e_pred.detach(), work) loss = ( loss_e + FORCE_WEIGHT * loss_f + STRESS_WEIGHT * loss_s + AUX_WEIGHT * (loss_aux + loss_diff) + ELECTRO_WEIGHT * loss_electro + MULTIFID_WEIGHT * loss_mf + SSL_WEIGHT * loss_ssl + BIO_WEIGHT * loss_bio + CONSIST_WEIGHT * loss_cons ) if torch.isnan(loss): continue if USE_FP16: scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP) scaler.step(optimizer) scaler.update() else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP) optimizer.step() total_loss += loss.item(); _ep=out1["energy"].detach().float().view(-1); _et=e_pred.detach().float().view(-1); mse_sum+=F.mse_loss(_ep,_et).item(); mae_sum+=F.l1_loss(_ep,_et).item() n_batch += 1 avg_loss = total_loss / max(n_batch, 1); avg_mse = mse_sum / max(n_batch, 1); avg_mae = mae_sum / max(n_batch, 1); print(f"Epoch {epoch+1:>3}/{epochs} | Loss: {avg_loss:.4f} | MSE: {avg_mse:.4f} | MAE: {avg_mae:.4f}"); scheduler.step() model.eval() val_loss = 0.0 val_n = 0 for batch in test_loader: pos = batch["positions"].to(device).clone().detach().requires_grad_(True) z = batch["atomic_numbers"].to(device) lattice = batch.get("lattice", torch.eye(3, device=device).unsqueeze(0).repeat(pos.shape[0], 1, 1) * 20.0) if isinstance(lattice, list): lattice = torch.stack(lattice).to(device) else: lattice = lattice.to(device) ld, gf, ie, gc, ss, hs, defect, diffusion, phonon = build_multi_physics_channels(pos, z) with torch.enable_grad(): out = model(ld, gf, ie, gc, ss, hs, defect, diffusion, phonon, pos, lattice=lattice, z=z) e_true = batch["energies_norm"].to(device).view(-1) val_loss += F.l1_loss(out["energy"].view(-1), e_true).item() val_n += 1 avg_val = val_loss / max(val_n, 1) if avg_val < best_val: best_val = avg_val torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best.pt")) if (epoch + 1) % 10 == 0: print(f"Epoch {epoch+1:03d} | Train Loss: {total_loss / max(n_batch, 1):.5f} | Val Loss: {avg_val:.5f}") torch.save(model.state_dict(), os.path.join(SAVE_DIR, "final.pt")) return model # ============================================================================= # 17. MAIN # ============================================================================= if __name__ == "__main__": print("=" * 100) print("MPHGNet v4.1-Q -- Quantum Hilbert + Maya E/I + Exact Sri Chakra Geometry") print("=" * 100) raw = load_mptrj_real(max_samples=8000) if len(raw) == 0: raw = load_oc20_real(max_samples=8000) if len(raw) == 0: raw = load_materials_project_data(api_key=os.environ.get("MP_API_KEY", None)) if len(raw) == 0: raise RuntimeError("No dataset could be loaded.") n_train = int(0.8 * len(raw)) train_raw = raw[:n_train] test_raw = raw[n_train:] train_ds = HEMBDataset(train_raw) test_ds = HEMBDataset(test_raw) plot_srichakra_cortical_map(os.path.join(PLOT_DIR, "srichakra_cortical_map.png")) model = EquivariantPhysicsBackboneV41(num_layers=NUM_LAYERS).to(device) model = train_model(model, train_ds, test_ds, epochs=40, batch_size=1) make_parity_plot(model, test_ds, device) ensemble = UncertaintyEnsemble(EquivariantPhysicsBackboneV41, num_models=2, num_layers=NUM_LAYERS).to(device) for m in ensemble.models: m.load_state_dict(torch.load(os.path.join(SAVE_DIR, "best.pt"), map_location=device), strict=False) # Active learning closed loop al_base = train_raw for cycle in range(2): print(f"\n[AL] Cycle {cycle + 1}") pool = generate_diverse_gb_dataset(n_configs=64) manager = ActiveLearningManager(ensemble, acquisition_size=16, work_dir=os.path.join(AL_DIR, f"cycle_{cycle+1}")) merged, selected, new_labels, scored = manager.run_cycle(al_base, pool) # Sleep consolidation manager.sleep_consolidation(ensemble.models[0], n_replays=200) al_base = merged train_ds = HEMBDataset(al_base) split = int(0.9 * len(al_base)) retrain_train = HEMBDataset(al_base[:split]) retrain_val = HEMBDataset(al_base[split:]) model = EquivariantPhysicsBackboneV41(num_layers=NUM_LAYERS).to(device) model = train_model(model, retrain_train, retrain_val, epochs=15, batch_size=1) for m in ensemble.models: m.load_state_dict(model.state_dict(), strict=False) print(f"[AL] Added {len(new_labels)} new DFT-labelled structures. Dataset size = {len(al_base)}") # Demonstrate BG-guided GCMC print("\n[GCMC] Running basal-ganglia-guided grand-canonical Monte Carlo...") demo_pos = train_raw[0]["positions"].to(device) demo_z = train_raw[0]["atomic_numbers"].to(device) fe_mask = (demo_z == 26) gcmc = GrandCanonicalMC(ensemble.models[0], demo_pos[fe_mask], demo_z[fe_mask], T=300.0, mu_H=-2.5, device=device) final_nH = gcmc.run(n_steps=500) print(f"[GCMC] Final H count after BG-guided insertion/removal: {final_nH}") print("\n=== AUTONOMOUS DISCOVERY LOOP COMPLETE ===")