threshold-atmost2outof4 / create_safetensors.py
CharlesCNorton
At most 2 of 4 threshold circuit, magnitude 6
bacce1d
Raw
History Blame Contribute Delete
832 Bytes
import itertools

import torch
from safetensors.torch import save_file
# Parameters of a single threshold neuron: output = [x . w + b >= 0].
# Every weight is -1 and the bias is +2, so the pre-activation equals
# 2 - (number of inputs set to 1). The neuron therefore fires exactly when
# at most two of its four binary inputs are 1. Total |weight| + |bias| = 6.
weights = {
    "neuron.weight": torch.full((1, 4), -1.0, dtype=torch.float32),
    "neuron.bias": torch.full((1,), 2.0, dtype=torch.float32),
}
# Serialize the parameters to model.safetensors (path is relative to the
# current working directory).
save_file(weights, "model.safetensors")
def atmost2of4(a, b, c, d, *, params=None):
    """Evaluate the at-most-2-of-4 threshold neuron on four inputs.

    Computes ``inputs @ weight.T + bias`` and thresholds the result at zero.

    Args:
        a, b, c, d: The four inputs; each is converted with ``float()``, so
            ints 0/1 and bools are accepted.
        params: Optional mapping holding ``'neuron.weight'`` (expected shape
            (1, 4)) and ``'neuron.bias'`` (expected shape (1,)) tensors, e.g.
            a dict read back from a saved ``model.safetensors``. When omitted,
            the module-level ``weights`` dict is used.

    Returns:
        1 if the pre-activation is >= 0, otherwise 0.
    """
    if params is None:
        params = weights
    weight = params['neuron.weight']
    bias = params['neuron.bias']
    # Build the input in the weight's dtype: matmul does not type-promote, so
    # a default-dtype input would fail against non-float32 weights.
    inp = torch.tensor([float(a), float(b), float(c), float(d)], dtype=weight.dtype)
    return int((inp @ weight.T + bias >= 0).item())
# Exhaustively check the neuron against the reference predicate on all 16
# binary input combinations, then report the circuit's total magnitude.
print("Verifying atmost2outof4...")
errors = 0
# product((0, 1), repeat=4) enumerates inputs in the same order as counting
# i = 0..15 with `a` as the most significant bit.
for a, b, c, d in itertools.product((0, 1), repeat=4):
    result = atmost2of4(a, b, c, d)
    expected = 1 if (a + b + c + d) <= 2 else 0
    if result != expected:
        errors += 1
        # Report each mismatch so a wrong weight/bias is easy to diagnose.
        print(f"  FAIL: inputs=({a}, {b}, {c}, {d}) expected {expected}, got {result}")
if errors == 0:
    print("All 16 test cases passed!")
else:
    # Emit an explicit verdict; otherwise a broken circuit produces no
    # failure message at all.
    print(f"{errors} of 16 test cases FAILED")
# Magnitude = sum of absolute values of all parameters (weights and bias).
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")