# File size: 1,461 bytes
# Revision d2fd8e4 (line-number gutter from the original listing removed)
import torch
from safetensors.torch import save_file
weights = {}
# 8-bit Funnel Shifter
# Combines two inputs and extracts 8 bits from an arbitrary bit position.


def add_neuron(name, w_list, bias):
    """Register one neuron's parameters in the module-level ``weights`` dict.

    Stores ``{name}.weight`` as a 1 x len(w_list) float32 tensor (a single
    weight row) and ``{name}.bias`` as a length-1 float32 tensor.
    """
    def _as_f32(values):
        return torch.tensor(values, dtype=torch.float32)

    # Weight first, then bias -- same insertion order as before.
    weights[f'{name}.weight'] = _as_f32([w_list])
    weights[f'{name}.bias'] = _as_f32([bias])
# Input vector (20 columns): A[7:0] in columns 0-7 (high byte),
# B[7:0] in columns 8-15 (low byte), S[3:0] in columns 16-19 (shift).
# Output: 8 bits of the 16-bit word {A,B}[15:0], starting at bit S.
#
# Selection logic: neuron sel_{i}_s{s} copies bit (i + s) of {A,B} to
# output bit i for shift s, using weight 1.0 on one column and bias -1.0.
# Bits 0-7 of {A,B} come from B (column bit + 8); bits 8-15 come from A
# (column bit - 8). For bit positions 0..15 both cases are (bit + 8) % 16.
# NOTE(review): the S columns (16-19) always keep weight 0 and only shifts
# 0-7 are generated, although S is described as 4 bits -- confirm how the
# shift input is meant to gate these neurons.
for out_bit in range(8):
    for shift in range(8):
        column = (out_bit + shift + 8) % 16  # swap halves: B low, A high
        row = [0.0] * 20
        row[column] = 1.0
        add_neuron(f'sel_{out_bit}_s{shift}', row, -1.0)
save_file(weights, 'model.safetensors')
def funnel_shift(a, b, s):
    """Return the 8-bit window of the 16-bit word {a, b} starting at bit s.

    ``a`` is the high byte and ``b`` the low byte; their concatenation is
    shifted right by ``s`` and truncated to its low 8 bits.
    """
    window = (a << 8 | b) >> s  # `<<` binds tighter than `|`
    return window & 0xFF
# Verify the selection neurons stored in `weights` against the reference
# funnel_shift(). The check evaluates the network itself rather than
# comparing funnel_shift() with an identical inline expression.


def _encode_inputs(a, b, s):
    """Pack A, B and S into the 20-element float32 input vector.

    Column k holds bit k of its field (LSB first): A in columns 0-7,
    B in columns 8-15, S in columns 16-19. This matches how the selection
    loop indexes A (column bit - 8) and B (column bit + 8).
    """
    bits = [(a >> k) & 1 for k in range(8)]
    bits += [(b >> k) & 1 for k in range(8)]
    bits += [(s >> k) & 1 for k in range(4)]
    return torch.tensor(bits, dtype=torch.float32)


def _fires(name, x):
    """Evaluate neuron `name` on input vector `x` with a step activation.

    NOTE(review): assumes the step fires when w.x + b >= 0. With weight 1.0
    and bias -1.0, a strict > 0 threshold would never fire -- confirm the
    convention used by whatever consumes model.safetensors.
    """
    pre = weights[f'{name}.weight'] @ x + weights[f'{name}.bias']
    return pre.item() >= 0


print("Verifying funnel shifter...")
errors = 0
# 0x01 / 0x80 are asymmetric patterns, so a bit-order reversal cannot pass.
for a in [0x00, 0x01, 0x55, 0x80, 0xAA, 0xFF]:
    for b in [0x00, 0x01, 0x55, 0x80, 0xAA, 0xFF]:
        for s in range(8):
            x = _encode_inputs(a, b, s)
            # Reassemble the 8 output bits from the per-bit neurons.
            result = sum(int(_fires(f'sel_{i}_s{s}', x)) << i for i in range(8))
            expected = funnel_shift(a, b, s)
            if result != expected:
                errors += 1
if errors == 0:
    print("All test cases passed!")
else:
    print(f"FAILED: {errors} errors")
# Report the total absolute weight mass and the total parameter count.
mag = 0
n_params = 0
for tensor in weights.values():
    mag += tensor.abs().sum().item()
    n_params += tensor.numel()
print(f"Magnitude: {mag:.0f}")
print(f"Parameters: {n_params}")