import torch

# Max of two 2-bit unsigned numbers, built from threshold neurons.
# Inputs:  a1, a0, b1, b0   (a = 2*a1 + a0, b = 2*b1 + b0)
# Outputs: m1, m0 = max(a, b)
#
# Logic: if a >= b then output a, else output b, where
#   a > b   iff (a1 > b1) OR (a1 == b1 AND a0 > b0)
#   a == b  iff (a1 == b1) AND (a0 == b0)
#
# Every neuron fires (outputs 1) when weight . input + bias >= 0; this is
# the test max2() applies to the stored tensors.


def _and(x, y):
    """Return (taps, bias) for a neuron that fires only when x and y are 1."""
    return {x: 1.0, y: 1.0}, -2.0


def _or(x, y):
    """Return (taps, bias) for a neuron that fires when x or y is 1."""
    return {x: 1.0, y: 1.0}, -1.0


def _nor(x, y):
    """Return (taps, bias) for a neuron that fires only when x and y are 0."""
    return {x: -1.0, y: -1.0}, 0.0


def _and_not(x, y):
    """Return (taps, bias) for a neuron that fires when x is 1 and y is 0."""
    return {x: 1.0, y: -1.0}, -1.0


def _wire(x):
    """Return (taps, bias) for a neuron that copies x to the next layer."""
    return {x: 1.0}, -0.5


def _wires(*names):
    """Return pass-through neurons, keyed by signal name, in the given order."""
    return {name: _wire(name) for name in names}


# Order of the four circuit inputs as seen by layer 1.
_INPUTS = ("a1", "a0", "b1", "b0")

# Single source of truth for the circuit: one (layer, neurons) entry per
# layer. The insertion order of `neurons` is that layer's output order, and
# therefore the input column order of the next layer's weight rows.
_LAYERS = (
    # Layer 1: per-bit comparisons, plus the two halves of a1 XNOR b1.
    ("l1", {
        "a1_gt_b1": _and_not("a1", "b1"),
        "b1_gt_a1": _and_not("b1", "a1"),
        "a0_gt_b0": _and_not("a0", "b0"),
        "b0_gt_a0": _and_not("b0", "a0"),
        "both1_high": _and("a1", "b1"),
        "both1_low": _nor("a1", "b1"),
        **_wires("a1", "a0", "b1", "b0"),
    }),
    # Layer 2: a1 == b1 (high bits both 1 or both 0).
    # NOTE: b1_gt_a1 is carried through here but no layer-3 neuron reads it;
    # it is kept so the saved tensor set stays unchanged.
    ("l2", {
        "a1_eq_b1": _or("both1_high", "both1_low"),
        **_wires("a1_gt_b1", "b1_gt_a1", "a0_gt_b0", "b0_gt_a0",
                 "a1", "a0", "b1", "b0"),
    }),
    # Layer 3: low-bit tie-break terms.
    ("l3", {
        "a_gt_b_part2": _and("a1_eq_b1", "a0_gt_b0"),
        "a0_neq_b0": _or("a0_gt_b0", "b0_gt_a0"),
        **_wires("a1_gt_b1", "a1", "a0", "b1", "b0", "a1_eq_b1"),
    }),
    # Layer 4: strict greater-than and equality of the full 2-bit values.
    ("l4", {
        "a_gt_b": _or("a_gt_b_part2", "a1_gt_b1"),
        "a_eq_b": _and_not("a1_eq_b1", "a0_neq_b0"),
        **_wires("a1", "a0", "b1", "b0"),
    }),
    # Layer 5: select signal, 1 when a >= b (ties output a).
    ("l5", {
        "a_ge_b": _or("a_gt_b", "a_eq_b"),
        **_wires("a1", "a0", "b1", "b0"),
    }),
    # Layer 6: MUX product terms, m = (a AND a_ge_b) OR (b AND NOT a_ge_b).
    ("l6", {
        "m1_a": _and("a_ge_b", "a1"),
        "m1_b": _and_not("b1", "a_ge_b"),
        "m0_a": _and("a_ge_b", "a0"),
        "m0_b": _and_not("b0", "a_ge_b"),
    }),
    # Layer 7: final OR of the MUX product terms.
    ("l7", {
        "m1": _or("m1_a", "m1_b"),
        "m0": _or("m0_a", "m0_b"),
    }),
)


def _build_weights():
    """Expand _LAYERS into the dense tensors that get saved.

    Returns:
        A dict mapping '<layer>.<neuron>.weight' to a float32 tensor of
        shape (1, n_inputs_of_layer) and '<layer>.<neuron>.bias' to a
        float32 tensor of shape (1,).

    Raises:
        KeyError: if a neuron taps a signal the previous layer does not emit.
    """
    tensors = {}
    sources = _INPUTS
    for layer, neurons in _LAYERS:
        column = {name: i for i, name in enumerate(sources)}
        for name, (taps, bias) in neurons.items():
            row = [0.0] * len(sources)
            for source, weight in taps.items():
                row[column[source]] = weight
            tensors[f'{layer}.{name}.weight'] = torch.tensor([row], dtype=torch.float32)
            tensors[f'{layer}.{name}.bias'] = torch.tensor([bias], dtype=torch.float32)
        # This layer's outputs become the next layer's input columns.
        sources = tuple(neurons)
    return tensors


weights = _build_weights()


def max2(a1, a0, b1, b0):
    """Evaluate the stored circuit on one pair of 2-bit inputs.

    The result is computed from the tensors in the module-level `weights`
    dict, layer by layer, so it checks the tensors themselves.

    Args:
        a1, a0: high and low bit of a (each 0 or 1).
        b1, b0: high and low bit of b (each 0 or 1).

    Returns:
        (m1, m0): high and low bit of the circuit output, as ints.
    """
    activations = [float(a1), float(a0), float(b1), float(b0)]
    for layer, neurons in _LAYERS:
        signal = torch.tensor(activations)
        activations = [
            float((signal @ weights[f'{layer}.{name}.weight'].T
                   + weights[f'{layer}.{name}.bias'] >= 0).item())
            for name in neurons
        ]
    m1, m0 = (int(bit) for bit in activations)
    return m1, m0


def _count_mismatches():
    """Check max2 against max() for all 16 input pairs; return the error count."""
    errors = 0
    for a in range(4):
        for b in range(4):
            a1, a0 = (a >> 1) & 1, a & 1
            b1, b0 = (b >> 1) & 1, b & 1
            m1, m0 = max2(a1, a0, b1, b0)
            result = 2*m1 + m0
            expected = max(a, b)
            if result != expected:
                errors += 1
                print(f"ERROR: max({a}, {b}) = {result}, expected {expected}")
    return errors


def main():
    """Verify the circuit, report its magnitude, and write model.safetensors.

    The file is written only when all 16 cases pass; on failure the process
    exits with status 1.
    """
    # Imported here so `weights` and `max2` stay importable with only torch.
    from safetensors.torch import save_file

    print("Verifying max2...")
    errors = _count_mismatches()
    if errors == 0:
        print("All 16 test cases passed!")
    else:
        print(f"FAILED: {errors} errors")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")

    if errors:
        # Do not overwrite the model file with a circuit that failed its check.
        raise SystemExit(1)
    save_file(weights, 'model.safetensors')


if __name__ == "__main__":
    main()