"""8-bit funnel shifter built from threshold neurons.

Combines two inputs and extracts 8 bits from an arbitrary position: the
16-bit word {A, B} (A high, B low) is shifted right by S and the low byte is
kept, i.e. ``(((A << 8) | B) >> S) & 0xFF`` for S in 0..15.

Input vector x (20 values, each 0 or 1, LSB first within every field):
    x[0:8]   = A bits 0..7   (high byte)
    x[8:16]  = B bits 0..7   (low byte)
    x[16:20] = S bits 0..3   (shift amount)

Each neuron is stored as ``{name}.weight`` with shape (1, fan_in) and
``{name}.bias`` with shape (1,). ``evaluate`` treats every neuron as a
threshold unit that outputs 1 iff ``w . x + b >= 0``.
NOTE(review): whatever consumes model.safetensors is assumed to use the same
activation -- confirm.

    sel_{i}_s{s}  (fan_in 20)    Fires iff S == s and bit i + s of {A, B} is 1.
                                 Only built for i + s < 16; past the top of
                                 the word the shifter pulls in zeros, so no
                                 neuron is needed.
    out_{i}       (fan_in 16-i)  OR of sel_{i}_s0 .. sel_{i}_s{15-i}, in that
                                 order. Its output is result bit i.

Running this file writes model.safetensors and checks the network against
the reference ``funnel_shift``.
"""
import torch

N_DATA = 8                        # width of A, B and of the result
N_SHIFT_BITS = 4                  # width of S
WORD = 2 * N_DATA                 # width of the concatenated word {A, B}
SHIFT_BASE = WORD                 # index of S bit 0 in the input vector
N_INPUTS = WORD + N_SHIFT_BITS    # 20 inputs in total


def add_neuron(weights, name, w_list, bias):
    """Store one neuron in `weights` as `{name}.weight` (1, len(w_list)) and `{name}.bias` (1,)."""
    weights[f'{name}.weight'] = torch.tensor([w_list], dtype=torch.float32)
    weights[f'{name}.bias'] = torch.tensor([bias], dtype=torch.float32)


def _data_index(src):
    """Map bit `src` (0..15) of the word {A, B} to its position in the input vector."""
    # Word bits 0..7 come from B, stored at x[8:16].
    # Word bits 8..15 come from A, stored at x[0:8].
    return N_DATA + src if src < N_DATA else src - N_DATA


def build_weights():
    """Return the name -> tensor dict for the complete funnel-shifter network."""
    weights = {}
    for i in range(N_DATA):
        for s in range(WORD - i):  # keeps the source bit i + s inside the word
            w = [0.0] * N_INPUTS
            w[_data_index(i + s)] = 1.0
            # Decode S == s with +1 on bits set in s and -1 on bits clear in s.
            # The S terms sum to popcount(s) only when S == s, and to at most
            # popcount(s) - 1 otherwise. With bias -(popcount(s) + 1), the
            # pre-activation reaches 0 only when S == s and the data bit is 1.
            for k in range(N_SHIFT_BITS):
                w[SHIFT_BASE + k] = 1.0 if (s >> k) & 1 else -1.0
            add_neuron(weights, f'sel_{i}_s{s}', w, -float(bin(s).count('1') + 1))
        # S has a single value, so at most one sel_{i}_* fires.
        # The OR therefore reduces to "sum >= 1".
        add_neuron(weights, f'out_{i}', [1.0] * (WORD - i), -1.0)
    return weights


def _encode(a, b, s):
    """Return the 20-element 0/1 input list for bytes a, b and shift s (0..15)."""
    return ([float((a >> k) & 1) for k in range(N_DATA)]
            + [float((b >> k) & 1) for k in range(N_DATA)]
            + [float((s >> k) & 1) for k in range(N_SHIFT_BITS)])


def _step(z):
    """Threshold activation: 1.0 where z >= 0, else 0.0."""
    return (z >= 0).to(torch.float32)


def evaluate(weights, cases):
    """Run the network on a list of (a, b, s) triples and return the output bytes.

    a and b are expected in 0..255 and s in 0..15; higher bits are ignored.
    """
    if not cases:
        return []
    x = torch.tensor([_encode(a, b, s) for a, b, s in cases], dtype=torch.float32)  # (N, 20)
    results = torch.zeros(len(cases), dtype=torch.int64)
    for i in range(N_DATA):
        names = [f'sel_{i}_s{t}' for t in range(WORD - i)]
        w1 = torch.cat([weights[f'{n}.weight'] for n in names])  # (16 - i, 20)
        b1 = torch.cat([weights[f'{n}.bias'] for n in names])    # (16 - i,)
        hidden = _step(x @ w1.T + b1)                             # (N, 16 - i)
        bit = _step(hidden @ weights[f'out_{i}.weight'].T + weights[f'out_{i}.bias'])  # (N, 1)
        results += bit.squeeze(1).to(torch.int64) * (1 << i)
    return results.tolist()


def funnel_shift(a, b, s):
    """Reference model: bits [s, s + 8) of the 16-bit word {a, b}."""
    combined = (a << 8) | b
    return (combined >> s) & 0xFF


def main():
    """Build the network, save it to model.safetensors, and verify it against the reference."""
    # Imported here so build_weights/evaluate stay usable without safetensors installed.
    from safetensors.torch import save_file

    weights = build_weights()
    save_file(weights, 'model.safetensors')

    print("Verifying funnel shifter...")
    patterns = [0x00, 0x01, 0x0F, 0x55, 0x80, 0xAA, 0xF0, 0xFF]
    cases = [(a, b, s)
             for a in patterns
             for b in patterns
             for s in range(1 << N_SHIFT_BITS)]
    got = evaluate(weights, cases)
    errors = sum(g != funnel_shift(a, b, s) for g, (a, b, s) in zip(got, cases))
    if errors == 0:
        print("All test cases passed!")
    else:
        print(f"FAILED: {errors} errors")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")
    print(f"Parameters: {sum(t.numel() for t in weights.values())}")


if __name__ == "__main__":
    main()