"""Build and verify a single threshold neuron computing "at most 3 of 5".

The neuron outputs 1 when ``3 - (a + b + c + d + e) >= 0``, i.e. when at most
three of its five 0/1 inputs are set. Running this file as a script saves the
weights to ``model.safetensors``, exhaustively checks all 32 inputs, and
prints the total absolute magnitude of the weights.
"""

import sys

import torch

# Single-neuron weights: every active input subtracts 1 from a bias of 3, so
# the pre-activation stays non-negative exactly while sum(inputs) <= 3.
weights = {
    'neuron.weight': torch.tensor(
        [[-1.0, -1.0, -1.0, -1.0, -1.0]], dtype=torch.float32
    ),
    'neuron.bias': torch.tensor([3.0], dtype=torch.float32),
}


def atmost3of5(a, b, c, d, e):
    """Return 1 if at most three of the five inputs are 1, else 0.

    Each argument is converted with ``float()``; callers are expected to pass
    0/1 bits. The result is the step function (``>= 0``) applied to the
    neuron pre-activation ``inp @ W.T + bias``.
    """
    inp = torch.tensor([float(a), float(b), float(c), float(d), float(e)])
    return int(
        (inp @ weights['neuron.weight'].T + weights['neuron.bias'] >= 0).item()
    )


def main():
    """Save the weights, verify all 32 inputs, and print the weight magnitude.

    Every disagreement with the reference ``sum(bits) <= 3`` is reported, and
    the process exits with status 1 if any case failed.
    """
    # Imported here so the neuron itself can be imported without safetensors.
    from safetensors.torch import save_file

    save_file(weights, 'model.safetensors')

    print("Verifying atmost3outof5...")
    failures = []
    for i in range(32):
        # Bit j of i becomes input j, enumerating every 5-bit combination.
        bits = [(i >> j) & 1 for j in range(5)]
        result = atmost3of5(*bits)
        expected = 1 if sum(bits) <= 3 else 0
        if result != expected:
            failures.append((bits, result, expected))

    if not failures:
        print("All 32 test cases passed!")
    else:
        for bits, result, expected in failures:
            print(f"FAIL {bits}: got {result}, expected {expected}")
        print(f"{len(failures)} of 32 test cases failed!")

    print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")

    # A non-zero exit status lets scripts/CI detect a broken neuron.
    if failures:
        sys.exit(1)


if __name__ == "__main__":
    main()