"""Build and verify a two-layer threshold network implementing a 16:1 multiplexer.

Input order: d0..d15, s3, s2, s1, s0 (20 inputs).

Layer 1 has 16 step neurons (activation fires when pre-activation >= 0);
neuron i fires only when the select lines encode i AND d_i is 1.
Layer 2 is a single OR neuron over the layer-1 outputs.

Running this file as a script saves the weights to ``model.safetensors``,
verifies the network, and prints the total absolute weight magnitude.
"""

import torch

NUM_DATA = 16
NUM_SELECT = 4
NUM_INPUTS = NUM_DATA + NUM_SELECT


def _layer1_row(i):
    """Return ``(weights, bias)`` for the layer-1 neuron that passes d_i when s == i."""
    w = [0.0] * NUM_INPUTS
    w[i] = 1.0  # data input d_i
    # Select inputs follow the data inputs, MSB first (s3, s2, s1, s0):
    # +1 where bit of i is set, -1 where it is clear.
    for k in range(NUM_SELECT):
        bit = (i >> (NUM_SELECT - 1 - k)) & 1
        w[NUM_DATA + k] = 1.0 if bit else -1.0
    # When s == i the select terms sum to popcount(i), so the pre-activation is
    # d_i + popcount(i) - (1 + popcount(i)) = d_i - 1, which is >= 0 iff d_i == 1.
    # Any mismatched select bit lowers the select sum by 1, giving at most
    # d_i - 2 < 0, so the neuron stays silent regardless of d_i.
    bias = -(1 + bin(i).count("1"))
    return w, bias


def build_weights():
    """Return the MUX16 weights as a dict of float32 tensors keyed by layer name."""
    rows, biases = zip(*(_layer1_row(i) for i in range(NUM_DATA)))
    return {
        "layer1.weight": torch.tensor(rows, dtype=torch.float32),
        "layer1.bias": torch.tensor(biases, dtype=torch.float32),
        # Layer 2: OR gate -- fires when at least one layer-1 neuron fires.
        "layer2.weight": torch.tensor([[1.0] * NUM_DATA], dtype=torch.float32),
        "layer2.bias": torch.tensor([-1.0], dtype=torch.float32),
    }


weights = build_weights()


def mux16(data, s3, s2, s1, s0):
    """Evaluate the network on 16 data bits and select bits s3..s0 (MSB first).

    Returns the network output as an int (0 or 1); for a correct network this
    equals ``data[s]`` where ``s = 8*s3 + 4*s2 + 2*s1 + s0``.
    """
    inp = torch.tensor(
        [float(d) for d in data] + [float(s3), float(s2), float(s1), float(s0)]
    )
    l1 = (inp @ weights["layer1.weight"].T + weights["layer1.bias"] >= 0).float()
    out = (l1 @ weights["layer2.weight"].T + weights["layer2.bias"] >= 0).float()
    return int(out.item())


def _test_patterns():
    """Yield data patterns that drive the unselected inputs both low and high."""
    yield [0] * NUM_DATA
    yield [1] * NUM_DATA
    for j in range(NUM_DATA):
        one_hot = [0] * NUM_DATA
        one_hot[j] = 1
        yield one_hot
        # One-cold: every other input is 1, so any leak of an unselected
        # input into the output is detected when j is the selected index.
        yield [1 - b for b in one_hot]
    yield [j & 1 for j in range(NUM_DATA)]
    yield [1 - (j & 1) for j in range(NUM_DATA)]


def verify():
    """Check that ``mux16`` returns ``data[s]`` for every select value.

    Prints a line per failing case and a summary line.

    Returns:
        ``(errors, test_count)``.
    """
    print("Verifying MUX16...")
    errors = 0
    test_count = 0
    for s in range(NUM_DATA):
        s3, s2, s1, s0 = (s >> 3) & 1, (s >> 2) & 1, (s >> 1) & 1, s & 1
        for data in _test_patterns():
            expected = data[s]
            result = mux16(data, s3, s2, s1, s0)
            test_count += 1
            if result != expected:
                errors += 1
                print(f"ERROR: s={s}, data={data} -> {result}, expected {expected}")

    if errors == 0:
        print(f"All {test_count} test cases passed!")
    else:
        print(f"FAILED: {errors} errors")
    return errors, test_count


def main():
    """Save the weights, verify the network and print the total weight magnitude."""
    # Imported here so the network/evaluator can be imported without safetensors.
    from safetensors.torch import save_file

    save_file(weights, "model.safetensors")
    verify()
    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")


if __name__ == "__main__":
    main()