#!/usr/bin/env python3
"""
Distillix Demo - The Resurrected BitNet Model
Minimal version to avoid Gradio 4.44.0 bugs
"""

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer


# ============ Model Definition ============

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), then scaled element-wise by `weight`.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class RotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used for rotary position embedding.

    Args:
        dim: Size of the per-head feature dimension the tables are built for.
        base: Frequency base; the default matches the value the original
            script hard-coded.
    """

    def __init__(self, dim, base=1000000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, T):
        """Return ``(cos, sin)`` tables for positions ``0..T-1``.

        ``x`` is only consulted for its device.
        """
        t = torch.arange(T, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate the frequencies so the table lines up with rotate_half's
        # split of the feature dimension into two halves.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def rotate_half(x):
    """Split the last dimension in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class Attention(nn.Module):
    """Causal self-attention with fewer key/value heads than query heads.

    Queries and keys are RMS-normalised per head and rotated with rotary
    embeddings; key/value heads are repeated to match the query head count.

    Args:
        dim: Model (residual stream) width.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; must divide ``n_heads``.
        head_dim: Feature size of each head.

    Raises:
        ValueError: If ``n_heads`` is not a multiple of ``n_kv_heads``.
    """

    def __init__(self, dim=768, n_heads=12, n_kv_heads=4, head_dim=64):
        super().__init__()
        if n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be a multiple of n_kv_heads ({n_kv_heads})"
            )
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        # How many query heads share each key/value head.
        self.n_rep = n_heads // n_kv_heads

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.rotary = RotaryEmbedding(head_dim)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(B, T, dim)``."""
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        # Rotary position embedding, broadcast over batch and head dimensions.
        cos, sin = self.rotary(x, T)
        cos, sin = cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0)
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)

        # Expand key/value heads so every query head has a partner.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product scores; positions above the diagonal (the future)
        # are masked out before the softmax.
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim ** 0.5
        future = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        attn = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

        merged = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(merged)


class MLP(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``.

    Args:
        dim: Model width.
        hidden_dim: Width of the gated hidden layer.
    """

    def __init__(self, dim=768, hidden_dim=2048):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add."""

    def __init__(self, dim=768, n_heads=12, n_kv_heads=4, head_dim=64, hidden_dim=2048):
        super().__init__()
        self.input_layernorm = RMSNorm(dim)
        self.self_attn = Attention(dim, n_heads, n_kv_heads, head_dim)
        self.post_attention_layernorm = RMSNorm(dim)
        self.mlp = MLP(dim, hidden_dim)

    def forward(self, x):
        x = x + self.self_attn(self.input_layernorm(x))
        return x + self.mlp(self.post_attention_layernorm(x))


class StudentLLM(nn.Module):
    """Decoder-only language model whose output projection reuses the embedding matrix.

    All arguments default to the sizes the original script hard-coded, so
    ``StudentLLM()`` builds exactly the same module tree (and state-dict keys)
    as before.
    """

    def __init__(
        self,
        vocab_size=32000,
        dim=768,
        n_layers=12,
        n_heads=12,
        n_kv_heads=4,
        head_dim=64,
        hidden_dim=2048,
    ):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(
            [
                Block(dim, n_heads, n_kv_heads, head_dim, hidden_dim)
                for _ in range(n_layers)
            ]
        )
        self.norm = RMSNorm(dim)

    def forward(self, input_ids):
        """Map token ids ``(B, T)`` to logits over the vocabulary."""
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        # Output head shares its weight with the input embedding.
        return F.linear(self.norm(x), self.embed_tokens.weight)


# ============ Load Model ============

print("Loading Distillix (The Resurrected BitNet)...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

HF_TOKEN = os.environ.get("HF_TOKEN", "")
model_path = hf_hub_download(
    repo_id="rileyseaburg/distillix",
    filename="inflation/inflation-2000.pt",
    token=HF_TOKEN if HF_TOKEN else None,
)

model = StudentLLM()
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only acceptable while the checkpoint repo is
# trusted; switch to weights_only=True if the checkpoint holds plain tensors.
ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
load_result = model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
# strict=False tolerates key mismatches, so surface them instead of silently
# running a partially (or entirely) un-loaded model.
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} key(s) missing from checkpoint "
        f"(kept at their freshly constructed values), e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} unexpected key(s) in checkpoint "
        f"were ignored, e.g. {load_result.unexpected_keys[:5]}"
    )
model.eval().to(device)

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token

# Compute stats
total_params = sum(p.numel() for p in model.parameters())
zero_params = sum((p == 0).sum().item() for p in model.parameters())
sparsity = 100 * zero_params / total_params
mlp_weights = [
    p.std().item()
    for name, p in model.named_parameters()
    if 'mlp' in name and 'weight' in name
]
mlp_std = sum(mlp_weights) / len(mlp_weights) if mlp_weights else 0
print("Model loaded!")
print(f"Model Stats: Sparsity={sparsity:.1f}%, MLP Std={mlp_std:.4f}")


# ============ Generation ============

@torch.no_grad()
def generate_response(user_message, max_tokens=150, temperature=0.7):
    """Generate a response from the model.

    Args:
        user_message: Text typed by the user; wrapped in the
            ``### User:`` / ``### Assistant:`` prompt template.
        max_tokens: Upper bound on the number of new tokens (coerced to int).
        temperature: Sampling temperature. Values at or below 0.01 switch to
            greedy (argmax) decoding.

    Returns:
        The assistant's reply as a string, or a short notice when the
        message is empty.
    """
    # Guard against None as well as blank input.
    if not user_message or not user_message.strip():
        return "Please enter a message."

    # Format as chat
    prompt = f"### User:\n{user_message}\n\n### Assistant:\n"

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

    # Keep only the most recent 512 prompt tokens.
    if input_ids.shape[1] > 512:
        input_ids = input_ids[:, -512:]

    prompt_len = input_ids.shape[1]
    generated = input_ids.clone()

    for _ in range(int(max_tokens)):
        # The model has no KV cache, so the whole sequence is re-run each step.
        logits = model(generated)[:, -1, :]

        if temperature > 0.01:
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
        else:
            next_token = logits.argmax(dim=-1, keepdim=True)

        generated = torch.cat([generated, next_token], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

        # Stop at next user turn
        if generated.shape[1] > prompt_len + 10:
            decoded_end = tokenizer.decode(generated[0][-15:])
            if "### User:" in decoded_end:
                break

    full_response = tokenizer.decode(generated[0], skip_special_tokens=True)

    if "### Assistant:" in full_response:
        response = full_response.split("### Assistant:")[-1].strip()
        # Drop anything the model wrote on behalf of the next user turn.
        if "### User:" in response:
            response = response.split("### User:")[0].strip()
    else:
        response = full_response

    return response


# ============ Gradio UI ============

# Import gradio AFTER model loading to avoid issues
import gradio as gr

# Use gr.Blocks with explicit, simple components to avoid schema bugs
with gr.Blocks(title="Distillix Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Distillix: The Resurrected BitNet

    **A 125M parameter BitNet 1.58-bit model resurrected from the dead.**

    This model was completely dead (99% zero weights) and brought back using
    Geometric Engineering (Wasserstein loss + SVD denoising). It now writes working Python code!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            user_input = gr.Textbox(
                label="Your Question",
                placeholder="Write a Python function to calculate fibonacci numbers",
                lines=3
            )
        with gr.Column(scale=1):
            max_tokens = gr.Slider(
                minimum=50, maximum=300, value=150, step=10,
                label="Max Tokens"
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=1.2, value=0.7, step=0.1,
                label="Temperature"
            )

    submit_btn = gr.Button("Generate", variant="primary")

    output = gr.Textbox(label="Distillix Response", lines=12)

    # Simple click handler
    submit_btn.click(
        fn=generate_response,
        inputs=[user_input, max_tokens, temperature],
        outputs=output
    )

    # Also submit on Enter
    user_input.submit(
        fn=generate_response,
        inputs=[user_input, max_tokens, temperature],
        outputs=output
    )

    gr.Markdown("""
    ---

    **Example prompts:**

    - Write a Python function to calculate fibonacci numbers
    - How do I reverse a string in Python?
    - Explain what a binary search algorithm does
    - Create a simple stack class in Python

    [Model Card](https://huggingface.co/rileyseaburg/distillix) | [Paper (Coming Soon)](#)
    """)


if __name__ == "__main__":
    # HuggingFace Spaces handles server_name/port
    demo.launch(server_name="0.0.0.0", server_port=7860)