"""Build and verify a single threshold neuron computing "at most 2 of 3".

The neuron outputs 1 when ``2 - (a + b + c) >= 0``, i.e. when no more than two
of the three binary inputs are set. Its parameters are saved to
``model.safetensors`` and then checked exhaustively against all 8 input
combinations; the script exits with status 1 if any case fails.
"""
import itertools
import sys

import torch
from safetensors.torch import save_file

weights = {
    'neuron.weight': torch.tensor([[-1.0, -1.0, -1.0]], dtype=torch.float32),
    'neuron.bias': torch.tensor([2.0], dtype=torch.float32),
}
save_file(weights, 'model.safetensors')


def test(a, b, c):
    """Evaluate the neuron on inputs (a, b, c) and return its 0/1 output.

    Computes ``w . x + bias`` with the module-level ``weights`` and returns 1
    when that pre-activation is >= 0, otherwise 0.
    """
    # Use the weights' dtype rather than torch's default so the matmul still
    # works if the default dtype has been changed (e.g. to float64).
    inp = torch.tensor(
        [float(a), float(b), float(c)], dtype=weights['neuron.weight'].dtype
    )
    pre_activation = inp @ weights['neuron.weight'].T + weights['neuron.bias']
    return int((pre_activation >= 0).item())


# Despite its name this is an evaluation helper, not a pytest test; opt out of
# collection in case the module is ever imported into a test file.
test.__test__ = False

print("Verifying atmost2outof3...")
errors = 0
# product(..., repeat=3) yields (a, b, c) in binary counting order 000..111.
for a, b, c in itertools.product((0, 1), repeat=3):
    result = test(a, b, c)
    expected = 1 if (a + b + c) <= 2 else 0
    if result != expected:
        errors += 1
        print(f"  FAIL: test({a}, {b}, {c}) = {result}, expected {expected}")
if errors == 0:
    print("All 8 test cases passed!")
else:
    print(f"{errors} of 8 test cases failed!")
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
if errors:
    # Non-zero exit status so calling scripts / CI notice a broken neuron.
    sys.exit(1)