"""Build, verify and export a 4-bit count-leading-zeros threshold network.

The network is two layers of threshold neurons.  A neuron outputs 1 when
``w . x + b >= 0`` and 0 otherwise.

* Layer 1 detects which input bit is the first one set, scanning from the
  MSB (or that no bit is set at all).
* Layer 2 encodes that one-hot result as the 3-bit binary count ``y2 y1 y0``.

Running the file as a script checks all 16 input combinations, writes the
weights to ``model.safetensors`` once they pass, and prints the total weight
magnitude.  Importing the module only defines ``weights`` and ``clz4``.
"""

import torch

weights = {}

# Input order: [a3, a2, a1, a0] (a3 is MSB)
# clz returns count of leading zeros (0-4)

# Layer 1: Priority detection from MSB
# A -1 weight on every higher bit vetoes the neuron when that bit is set, so
# at most one of the "first set" neurons fires for any binary input.

# has3: a3 is set (clz = 0)
weights['has3.weight'] = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32)
weights['has3.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# has2_first: a2 is first set from MSB (clz = 1)
weights['has2_first.weight'] = torch.tensor([[-1.0, 1.0, 0.0, 0.0]], dtype=torch.float32)
weights['has2_first.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# has1_first: a1 is first set from MSB (clz = 2)
weights['has1_first.weight'] = torch.tensor([[-1.0, -1.0, 1.0, 0.0]], dtype=torch.float32)
weights['has1_first.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# has0_first: a0 is first set from MSB (clz = 3)
weights['has0_first.weight'] = torch.tensor([[-1.0, -1.0, -1.0, 1.0]], dtype=torch.float32)
weights['has0_first.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# all_zero: no bits set (clz = 4)
weights['all_zero.weight'] = torch.tensor([[-1.0, -1.0, -1.0, -1.0]], dtype=torch.float32)
weights['all_zero.bias'] = torch.tensor([0.0], dtype=torch.float32)

# Layer 2: Encode to binary
# Input order: [has3, has2_first, has1_first, has0_first, all_zero]
# clz 0=000, 1=001, 2=010, 3=011, 4=100

# y0 = has2_first OR has0_first (clz is 1 or 3)
weights['y0.weight'] = torch.tensor([[0.0, 1.0, 0.0, 1.0, 0.0]], dtype=torch.float32)
weights['y0.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# y1 = has1_first OR has0_first (clz is 2 or 3)
weights['y1.weight'] = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32)
weights['y1.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# y2 = all_zero (clz is 4)
weights['y2.weight'] = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0]], dtype=torch.float32)
weights['y2.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# Neuron evaluation order.  The layer-1 order must match the column order of
# the layer-2 weight rows; the layer-2 order is MSB first.
_LAYER1 = ('has3', 'has2_first', 'has1_first', 'has0_first', 'all_zero')
_LAYER2_MSB_FIRST = ('y2', 'y1', 'y0')


def _fires(name: str, x: torch.Tensor) -> int:
    """Return 1 if threshold neuron *name* fires on input vector *x*, else 0."""
    pre_activation = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
    return int((pre_activation >= 0).item())


def clz4(a3, a2, a1, a0):
    """Evaluate the threshold network on one 4-bit input.

    Args:
        a3, a2, a1, a0: Input bits (0 or 1), with ``a3`` the MSB.

    Returns:
        ``[y2, y1, y0]``: the leading-zero count (0-4) as three bits, MSB
        first.
    """
    inp = torch.tensor([float(a3), float(a2), float(a1), float(a0)])
    # Layer 1: one-hot "first set bit" detection.
    hidden = torch.tensor([float(_fires(name, inp)) for name in _LAYER1])
    # Layer 2: binary encoding of the detected position.
    return [_fires(name, hidden) for name in _LAYER2_MSB_FIRST]


def _expected_clz(a3, a2, a1, a0) -> int:
    """Return the reference leading-zero count of the 4-bit input."""
    for count, bit in enumerate((a3, a2, a1, a0)):
        if bit:
            return count
    return 4


def _verify() -> int:
    """Check ``clz4`` against the reference on all 16 inputs.

    Prints one line per mismatch and returns the number of mismatches.
    """
    print("Verifying clz4...")
    errors = 0
    for i in range(16):
        a3, a2, a1, a0 = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
        expected_clz = _expected_clz(a3, a2, a1, a0)
        expected = [(expected_clz >> 2) & 1, (expected_clz >> 1) & 1, expected_clz & 1]
        result = clz4(a3, a2, a1, a0)
        if result != expected:
            errors += 1
            print(f"ERROR: {a3}{a2}{a1}{a0} clz={expected_clz} -> {result}, expected {expected}")
    return errors


def main() -> int:
    """Verify the network, export it if it is correct, and report its size.

    Returns:
        Process exit status: 0 when all 16 cases pass, 1 otherwise.
    """
    errors = _verify()

    if errors == 0:
        print("All 16 test cases passed!")
        # Imported here so that ``weights`` and ``clz4`` stay importable
        # without the serialization dependency installed.
        from safetensors.torch import save_file

        # Only a verified weight table is written to disk.
        save_file(weights, 'model.safetensors')
    else:
        print(f"FAILED: {errors} errors")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")
    return 1 if errors else 0


if __name__ == "__main__":
    raise SystemExit(main())