Spaces:
Configuration error
Configuration error
Upload app.py with huggingface_hub
Browse files
app.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
import torch, torch.nn as nn, torch.nn.functional as F, math, gc
|
| 3 |
+
from scipy.stats import norm as sp_norm
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 5 |
+
from torchao.quantization import quantize_, Int4WeightOnlyConfig
|
| 6 |
+
import torchao.quantization.utils as _tao_utils
|
| 7 |
+
from threading import Thread
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
# Hugging Face Hub repository id passed to both from_pretrained() calls below.
MODEL = "nvidia/Nemotron-Cascade-2-30B-A3B"
|
| 11 |
+
# Block size: weight rows are split into BS-wide blocks and rotated with a
# BS x BS Hadamard matrix before being snapped to the centroid codebook.
BS = 128
|
| 12 |
+
# Cache of Lloyd-Max codebooks keyed by bit width; filled by get_centroids().
_C = {}


def get_centroids(bits):
    """Return the ``2**bits`` Lloyd-Max centroids for a standard normal source.

    The decision boundaries start as a uniform grid on ``[-4, 4]``; the two
    outermost edges stay pinned at -4 and +4, so the codebook is fitted to a
    standard normal truncated to that interval.  One hundred Lloyd iterations
    alternate between

    * setting every centroid to the conditional mean of N(0, 1) on its cell,
      ``(pdf(a) - pdf(b)) / (cdf(b) - cdf(a))``, and
    * moving every interior boundary to the midpoint of its two neighbouring
      centroids.

    Results are memoised in the module-level ``_C`` dict, so repeated calls
    with the same ``bits`` return the very same tensor object.

    Args:
        bits: Non-negative integer bit width of the codebook.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``2**bits``, sorted ascending.
    """
    if bits in _C:
        return _C[bits]

    levels = 1 << bits
    bd = torch.linspace(-4.0, 4.0, levels + 1)
    ct = torch.zeros(levels)

    for _ in range(100):
        # Evaluate pdf/cdf once per boundary (in float64) instead of issuing
        # four scalar scipy calls per cell.
        edges = bd.double()
        edges_np = edges.numpy()
        cdf = torch.from_numpy(sp_norm.cdf(edges_np))
        pdf = torch.from_numpy(sp_norm.pdf(edges_np))

        mass = cdf[1:] - cdf[:-1]
        cell_mean = (pdf[:-1] - pdf[1:]) / mass
        cell_mid = (edges[:-1] + edges[1:]) / 2
        # Cells carrying (almost) no probability mass fall back to their
        # geometric midpoint, which also discards any inf/nan from the
        # division above.
        ct = torch.where(mass > 1e-12, cell_mean, cell_mid).to(torch.float32)

        # Interior boundaries only; bd[0] and bd[-1] never move.
        bd[1:-1] = (ct[:-1] + ct[1:]) / 2

    _C[bits] = ct
    return ct
|
| 23 |
+
|
| 24 |
+
def _build_H(n):
    """Return the ``n x n`` orthonormal Hadamard matrix (Sylvester construction).

    Starting from ``[[1.0]]`` the matrix is doubled with
    ``[[H, H], [H, -H]] / sqrt(2)`` until it reaches order ``n``.  The result
    is symmetric with orthonormal rows, so it is its own inverse.

    Args:
        n: Order of the matrix; must be a positive power of two.

    Returns:
        A ``torch.float32`` tensor of shape ``(n, n)``.

    Raises:
        ValueError: If ``n`` is not a positive power of two.  Without this
            check such an order would yield a matrix of the wrong size, and
            ``n == 0`` would recurse without terminating.
    """
    order = int(n)
    if order != n or order < 1 or order & (order - 1):
        raise ValueError(f"Hadamard order must be a positive power of two, got {n!r}")

    h = torch.tensor([[1.0]])
    while h.shape[0] < order:
        top = torch.cat([h, h], 1)
        bottom = torch.cat([h, -h], 1)
        # Normalise at every doubling step so each level stays orthonormal.
        h = torch.cat([top, bottom], 0) / math.sqrt(2)
    return h
|
| 28 |
+
|
| 29 |
+
# Fill the centroid cache at import time for the 2- to 6-bit codebooks
# (respond() below only reads the 5-bit table).
for b in [2,3,4,5,6]: get_centroids(b)
|
| 30 |
+
# BS x BS Hadamard rotation built once on the CPU; respond() copies it to CUDA.
H_W = _build_H(BS)
|
| 31 |
+
|
| 32 |
+
def should_quantize(name, param):
    """Decide whether a weight tensor should go through quantization.

    A tensor is skipped when it

    * has fewer than two dimensions or fewer than 256 elements (this already
      covers every 1-D bias vector),
    * belongs to a normalisation layer (any name containing ``"norm"``),
    * has a name containing ``"A_log"``, ``".D"``, ``"dt_bias"`` or
      ``"conv1d"``,
    * is a routing gate: the name contains ``"router"`` or ends in ``".gate"``
      / ``".gate.weight"``.  Both spellings are accepted because callers may
      pass either a module name or a parameter name.

    Args:
        name: Dotted module or parameter name (matched case-sensitively).
        param: Tensor-like object exposing ``ndim`` and ``numel()``.

    Returns:
        ``True`` if the tensor should be quantized, ``False`` otherwise.
    """
    if param.ndim < 2 or param.numel() < 256:
        return False

    skip_markers = ("norm", "A_log", ".D", "dt_bias", "conv1d", "router")
    if any(marker in name for marker in skip_markers):
        return False

    return not name.endswith((".gate", ".gate.weight"))
|
| 39 |
+
|
| 40 |
+
# Keep a handle on torchao's own guard; respond() puts it back once the
# one-time quantization pass has finished.
_orig = _tao_utils.guard_dtype_size
|
| 41 |
+
def _patched(t, n, dtype=None, size=None):
    """Lenient replacement for ``torchao.quantization.utils.guard_dtype_size``.

    A dtype mismatch is repaired by casting the tensor's data in place rather
    than being treated as an error (NOTE(review): the upstream guard
    presumably rejects mismatched dtypes - confirm against the installed
    torchao version).  A size mismatch is still fatal.

    Args:
        t: Tensor to check; its ``.data`` may be replaced by a cast copy.
        n: Name of the argument being checked, used in the error message.
        dtype: Expected dtype, or ``None`` to skip the dtype check.
        size: Expected ``torch.Size``, or ``None`` to skip the size check.

    Raises:
        ValueError: If ``size`` is given and differs from ``t.size()``.
    """
    if dtype is not None and t.dtype != dtype:
        t.data = t.data.to(dtype)
    if size is not None and t.size() != size:
        raise ValueError(f"{n}: expected size {size}, got {t.size()}")
|
| 44 |
+
# Swap in the lenient guard for the duration of the quantization pass.
_tao_utils.guard_dtype_size = _patched
|
| 45 |
+
|
| 46 |
+
# Announce the checkpoint load in the logs before it starts.
print("Loading Nemotron on CPU...")

# The checkpoint is loaded on the CPU in bfloat16; respond() moves it to the
# GPU on the first request.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# model repository - only keep it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="cpu",
    attn_implementation="sdpa",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

# Set to True by respond() after the one-time quantize-and-move-to-GPU pass.
_loaded = False
|
| 51 |
+
|
| 52 |
+
def _polar_roundtrip(weight, hadamard, centroids):
    """Snap a weight matrix to the centroid codebook and decode it again.

    The matrix is copied to CUDA as float32, zero-padded so its input
    dimension is a multiple of ``BS``, and split into ``BS``-wide blocks.
    Each block is rotated with ``hadamard``, scaled to norm ``sqrt(BS)``,
    replaced element-wise by its nearest centroid, rescaled, rotated back with
    the same matrix and multiplied by the block's original norm.

    Args:
        weight: 2-D weight tensor of shape ``(out_features, in_features)``.
        hadamard: ``BS x BS`` rotation matrix on CUDA.
        centroids: 1-D codebook tensor on CUDA.

    Returns:
        A CUDA ``torch.bfloat16`` tensor with the same shape as ``weight``.
    """
    w = weight.data.float().to("cuda")
    out_f, in_f = w.shape
    pad = (BS - in_f % BS) % BS
    if pad > 0:
        w = F.pad(w, (0, pad))
    nb = w.shape[1] // BS
    w = w.reshape(out_f, nb, BS)

    # Forward rotation, processed in 64-row slabs (presumably to bound the
    # size of the temporary matmul result).
    for lo in range(0, out_f, 64):
        hi = min(lo + 64, out_f)
        w[lo:hi] = (w[lo:hi].reshape(-1, BS) @ hadamard).reshape(hi - lo, nb, BS)

    # Per-block norms are kept in float32 and re-applied after decoding.
    norms = w.norm(dim=2, keepdim=True).clamp(min=1e-10)
    w.div_(norms).mul_(math.sqrt(BS))

    # Nearest-centroid index for every element, 256 rows at a time because
    # the broadcasted distance tensor has one extra axis of codebook size.
    codes = torch.empty(out_f, nb, BS, dtype=torch.int8, device="cuda")
    for lo in range(0, out_f, 256):
        hi = min(lo + 256, out_f)
        dist = (w[lo:hi].unsqueeze(-1) - centroids.view(1, 1, 1, -1)).abs()
        codes[lo:hi] = dist.argmin(-1).to(torch.int8)
    del w

    # Decode: look the centroids up again and undo the sqrt(BS) scaling.
    vals = torch.empty(out_f, nb, BS, dtype=torch.float32, device="cuda")
    for lo in range(0, out_f, 256):
        hi = min(lo + 256, out_f)
        vals[lo:hi] = centroids[codes[lo:hi].long()] / math.sqrt(BS)
    del codes
    torch.cuda.empty_cache()

    # Inverse rotation: the same matrix is applied a second time.
    for lo in range(0, out_f, 64):
        hi = min(lo + 64, out_f)
        vals[lo:hi] = (vals[lo:hi].reshape(-1, BS) @ hadamard).reshape(hi - lo, nb, BS)

    vals *= norms
    del norms
    # Drop the padding columns before narrowing to bfloat16.
    bf16_w = vals.reshape(out_f, -1)[:, :in_f].to(torch.bfloat16)
    del vals
    torch.cuda.empty_cache()
    return bf16_w


def _quantize_model_on_gpu():
    """Quantize every eligible ``nn.Linear`` of ``model`` and move it to CUDA.

    Each selected weight goes through ``_polar_roundtrip`` and is then packed
    with torchao's ``Int4WeightOnlyConfig(group_size=128)``.  If the INT4
    packing raises, the layer keeps the decoded bfloat16 weight instead and
    the failure is printed.  Afterwards torchao's original
    ``guard_dtype_size`` is restored and every parameter and buffer still on
    the CPU is moved to CUDA.

    Returns:
        The number of linear layers that were processed.
    """
    hadamard = H_W.to("cuda")
    centroids = get_centroids(5).to("cuda")
    int4_cfg = Int4WeightOnlyConfig(group_size=128)
    processed = 0

    for name, child in list(model.named_modules()):
        if not isinstance(child, nn.Linear) or child.weight.device.type == "meta":
            continue
        if not should_quantize(name, child.weight):
            continue

        out_f, in_f = child.weight.shape
        bf16_w = _polar_roundtrip(child.weight, hadamard, centroids)

        dummy = None
        try:
            # The throw-away wrapper is created on the meta device so no real
            # storage is allocated before the decoded weight is attached.
            with torch.device("meta"):
                dummy = nn.Sequential(nn.Linear(in_f, out_f, bias=False))
            dummy[0].weight = nn.Parameter(bf16_w)
            quantize_(dummy, int4_cfg)
            child.weight = dummy[0].weight
        except Exception as exc:
            # Best-effort fallback: keep the decoded bf16 weight rather than
            # aborting the whole load, but say so in the logs.
            print(f"INT4 packing failed for {name}; keeping bf16 weight ({exc!r})")
            child.weight.data = bf16_w.cpu()
        finally:
            # Release the wrapper on both paths so a failed attempt does not
            # keep the GPU tensor alive.
            del dummy

        del bf16_w
        torch.cuda.empty_cache()
        processed += 1

    _tao_utils.guard_dtype_size = _orig

    for _, p in model.named_parameters():
        if p.device.type == "cpu":
            p.data = p.data.to("cuda")
    for _, b in model.named_buffers():
        if b.device.type == "cpu":
            b.data = b.data.to("cuda")

    gc.collect()
    torch.cuda.empty_cache()
    return processed


@spaces.GPU(duration=300)
def respond(message, history):
    """Stream a chat reply for ``message`` given the previous ``history``.

    On the first call the model is quantized and moved to CUDA (see
    ``_quantize_model_on_gpu``).  Generation runs in a background thread and
    its output is consumed through a ``TextIteratorStreamer``.

    Args:
        message: The new user message.
        history: Iterable of earlier chat messages; it is copied into a list
            and the new user message is appended as a role/content dict.

    Yields:
        The reply accumulated so far, once per streamed text chunk.
    """
    global _loaded
    if not _loaded:
        layer_count = _quantize_model_on_gpu()
        _loaded = True
        print(f"Ready! {layer_count} layers, {torch.cuda.memory_allocated()/1e9:.1f} GB")

    messages = list(history) + [{"role": "user", "content": message}]
    chat_out = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
    # apply_chat_template may hand back a mapping with "input_ids" or a bare
    # tensor; accept both.
    if hasattr(chat_out, "input_ids"):
        input_ids = chat_out["input_ids"].to("cuda")
    else:
        input_ids = chat_out.to("cuda")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.3,
        streamer=streamer,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial = ""
    for text in streamer:
        partial += text
        yield partial
    # The streamer is exhausted only once generate() has finished producing
    # text, so this join returns promptly and leaves no stray thread behind.
    worker.join()
|
| 98 |
+
|
| 99 |
+
# Chat UI wired to respond(); type="messages" selects the role/content
# history format that respond() extends with the new user message.
demo = gr.ChatInterface(
    respond,
    title="🧊 Nemotron Cascade 30B-A3B — PolarQuant Q5+INT4",
    description="30B MoE (3B active) | 735 downloads | Most popular PolarQuant model | [Paper](https://arxiv.org/abs/2603.29078)",
    examples=[
        "Explain quantum computing simply.",
        "Write a Python binary search.",
        "Compare TCP vs UDP.",
    ],
    type="messages",
)

# Start the Gradio app as soon as the module is executed.
demo.launch()
|