"""Gradio chat demo: Nemotron Cascade 30B-A3B with PolarQuant Q5 + INT4 weights.

The model is loaded on CPU in bfloat16 at import time. On the first request
(inside the GPU-decorated handler) every eligible ``nn.Linear`` weight is
passed through a block-Hadamard / 5-bit centroid round trip, handed to
torchao for INT4 weight-only packing, and the whole model is moved to CUDA.
Later requests only run generation.
"""
# NOTE(review): `spaces` is kept as the very first import, exactly as in the
# original file; ZeroGPU setups generally expect it to be imported before
# torch -- confirm before reordering.
import spaces

import gc
import math
from threading import Thread

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchao.quantization.utils as _tao_utils
from scipy.stats import norm as sp_norm
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL = "nvidia/Nemotron-Cascade-2-30B-A3B"
BS = 128  # block size used for the Hadamard rotation and per-block norms
_C = {}  # cache: bit width -> tensor of 2**bits centroids


def get_centroids(bits):
    """Return the ``2**bits`` quantization centroids for a standard normal.

    Runs 100 Lloyd-Max style iterations on the interval [-4, 4]: each
    centroid is the conditional mean of N(0, 1) over its cell, and each inner
    boundary is then moved to the midpoint of its two neighbouring centroids.
    The outer boundaries stay fixed at -4 and 4. Results are memoised in
    ``_C``.

    Args:
        bits: Number of bits per code; the codebook has ``1 << bits`` entries.

    Returns:
        A 1-D float tensor of centroids in ascending order.
    """
    if bits in _C:
        return _C[bits]

    levels = 1 << bits
    bounds = torch.linspace(-4.0, 4.0, levels + 1)
    centroids = torch.zeros(levels)
    for _ in range(100):
        for i in range(levels):
            lo, hi = bounds[i].item(), bounds[i + 1].item()
            mass = sp_norm.cdf(hi) - sp_norm.cdf(lo)
            if mass > 1e-12:
                # Mean of N(0, 1) restricted to [lo, hi].
                centroids[i] = (sp_norm.pdf(lo) - sp_norm.pdf(hi)) / mass
            else:
                # Cell carries (numerically) no mass: fall back to its midpoint.
                centroids[i] = (lo + hi) / 2
        for i in range(1, levels):
            bounds[i] = (centroids[i - 1] + centroids[i]) / 2

    _C[bits] = centroids
    return centroids


def _build_H(n):
    """Build the orthonormal ``n x n`` Hadamard matrix (Sylvester recursion).

    Args:
        n: Matrix size; must be a positive power of two.

    Returns:
        An ``n x n`` float tensor ``H`` scaled by ``1/sqrt(2)`` at every
        doubling step.

    Raises:
        ValueError: If ``n`` is not a positive power of two. Without this
            check the recursion silently returns a matrix of the wrong size
            (or never terminates for ``n == 0``).
    """
    if n < 1 or n & (n - 1):
        raise ValueError(f"n must be a positive power of two, got {n}")
    if n == 1:
        return torch.tensor([[1.0]])
    half = _build_H(n // 2)
    top = torch.cat([half, half], 1)
    bottom = torch.cat([half, -half], 1)
    return torch.cat([top, bottom], 0) / math.sqrt(2)


# Warm the centroid cache for the supported bit widths and build the rotation.
for b in [2, 3, 4, 5, 6]:
    get_centroids(b)
H_W = _build_H(BS)

# Substrings / suffixes of parameter names that must stay unquantized.
_NORM_KEYS = ("norm", "layernorm", "rmsnorm")
_STATE_KEYS = ("A_log", ".D", "dt_bias", "conv1d")


def should_quantize(name, param):
    """Decide whether the parameter ``name`` should be quantized.

    Skips tensors with fewer than two dimensions or fewer than 256 elements,
    anything whose name mentions a norm layer, the ``A_log`` / ``.D`` /
    ``dt_bias`` / ``conv1d`` parameters, 1-D biases, and router / gate
    weights.

    Args:
        name: Qualified module or parameter name.
        param: The tensor under consideration.

    Returns:
        ``True`` if the tensor should be quantized, ``False`` otherwise.
    """
    if param.ndim < 2 or param.numel() < 256:
        return False
    if any(key in name for key in _NORM_KEYS):
        return False
    if any(key in name for key in _STATE_KEYS):
        return False
    if "bias" in name and param.ndim == 1:
        return False
    if name.endswith(".gate.weight") or "router" in name:
        return False
    return True


_orig = _tao_utils.guard_dtype_size


def _patched(t, n, dtype=None, size=None):
    """Lenient stand-in for torchao's ``guard_dtype_size``.

    A dtype mismatch is repaired by casting ``t`` in place; a size mismatch
    still raises.

    Args:
        t: Tensor being checked.
        n: Unused here; kept so the signature matches the replaced function.
        dtype: Expected dtype, or ``None`` to skip the dtype check.
        size: Expected size, or ``None`` to skip the size check.

    Raises:
        ValueError: If ``size`` is given and differs from ``t.size()``.
    """
    if dtype is not None and t.dtype != dtype:
        t.data = t.data.to(dtype)
    if size is not None and t.size() != size:
        raise ValueError(f"{size} vs {t.size()}")


_tao_utils.guard_dtype_size = _patched

print("Loading Nemotron on CPU...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="cpu",
    attn_implementation="sdpa",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
_loaded = False


def _polar_roundtrip(weight, hadamard, centroids):
    """Quantize-dequantize one weight matrix through the rotated codebook.

    Each row is split into blocks of ``BS`` columns (zero-padded if needed),
    rotated by ``hadamard``, normalised per block, snapped to the nearest
    entry of ``centroids``, and then mapped back through the same steps in
    reverse.

    Args:
        weight: 2-D weight tensor of shape ``(out_features, in_features)``.
        hadamard: ``BS x BS`` rotation matrix, already on the target device.
        centroids: 1-D codebook tensor on the same device as ``hadamard``.

    Returns:
        A bfloat16 tensor with the shape of ``weight`` on ``hadamard``'s
        device.
    """
    device = hadamard.device
    rot_rows, code_rows = 64, 256  # row chunk sizes that bound peak memory
    scale = math.sqrt(BS)

    # Transfer first, then upcast on the device (same values, half the copy).
    w = weight.data.to(device).float()
    out_f, in_f = w.shape
    pad = (BS - in_f % BS) % BS
    if pad > 0:
        w = F.pad(w, (0, pad))
    nb = w.shape[1] // BS
    w = w.reshape(out_f, nb, BS)

    # Forward rotation, chunked over rows.
    for start in range(0, out_f, rot_rows):
        stop = min(start + rot_rows, out_f)
        w[start:stop] = (w[start:stop].reshape(-1, BS) @ hadamard).reshape(stop - start, nb, BS)

    # Per-block norm is kept aside; the block itself is rescaled to norm sqrt(BS).
    norms = w.norm(dim=2, keepdim=True).clamp(min=1e-10)
    w.div_(norms).mul_(scale)

    # Nearest-centroid search; chunked because it broadcasts against the codebook.
    codes = torch.empty(out_f, nb, BS, dtype=torch.int8, device=device)
    for start in range(0, out_f, code_rows):
        stop = min(start + code_rows, out_f)
        dist = (w[start:stop].unsqueeze(-1) - centroids.view(1, 1, 1, -1)).abs()
        codes[start:stop] = dist.argmin(-1).to(torch.int8)
    del w

    # Decode: look the centroids back up and undo the sqrt(BS) scaling.
    vals = torch.empty(out_f, nb, BS, dtype=torch.float32, device=device)
    for start in range(0, out_f, code_rows):
        stop = min(start + code_rows, out_f)
        vals[start:stop] = centroids[codes[start:stop].long()] / scale
    del codes
    torch.cuda.empty_cache()

    # Apply the rotation again (the matrix is symmetric and orthonormal).
    for start in range(0, out_f, rot_rows):
        stop = min(start + rot_rows, out_f)
        vals[start:stop] = (vals[start:stop].reshape(-1, BS) @ hadamard).reshape(stop - start, nb, BS)
    vals *= norms
    del norms

    # Drop the padding columns and return in the model's dtype.
    result = vals.reshape(out_f, -1)[:, :in_f].to(torch.bfloat16)
    del vals
    torch.cuda.empty_cache()
    return result


def _quantize_model_on_gpu():
    """Quantize every eligible linear layer of ``model`` and move it to CUDA.

    Returns:
        The number of linear layers that were processed.
    """
    hadamard = H_W.to("cuda")
    centroids = get_centroids(5).to("cuda")
    int4_cfg = Int4WeightOnlyConfig(group_size=128)
    processed = 0

    for name, child in list(model.named_modules()):
        if not isinstance(child, nn.Linear) or child.weight.device.type == "meta":
            continue
        if not should_quantize(name, child.weight):
            continue

        out_f, in_f = child.weight.shape
        bf16_w = _polar_roundtrip(child.weight, hadamard, centroids)
        try:
            # Build the shell on the meta device so no real storage is allocated,
            # then swap in the processed weight and let torchao pack it.
            with torch.device("meta"):
                dummy = nn.Sequential(nn.Linear(in_f, out_f, bias=False))
            dummy[0].weight = nn.Parameter(bf16_w)
            quantize_(dummy, int4_cfg)
            child.weight = dummy[0].weight
            del dummy
        except Exception as exc:
            # Best effort: keep the dequantized bf16 weight if INT4 packing
            # fails, but say so instead of hiding the failure.
            print(f"INT4 quantization failed for {name}: {exc!r}; keeping bf16 weights")
            child.weight.data = bf16_w.cpu()
        del bf16_w
        torch.cuda.empty_cache()
        processed += 1

    _tao_utils.guard_dtype_size = _orig

    # Whatever was not quantized (or fell back to CPU) still has to be moved.
    for _, p in model.named_parameters():
        if p.device.type == "cpu":
            p.data = p.data.to("cuda")
    for _, buf in model.named_buffers():
        if buf.device.type == "cpu":
            buf.data = buf.data.to("cuda")
    gc.collect()
    torch.cuda.empty_cache()
    return processed


@spaces.GPU(duration=300)
def respond(message, history):
    """Stream a chat reply for ``message`` given the prior ``history``.

    The first call also performs the one-time quantization and GPU transfer.

    Args:
        message: The new user message text.
        history: Earlier conversation turns as a sequence of message dicts
            (the interface is created with ``type="messages"``).

    Yields:
        The reply accumulated so far, growing with every streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has ended.
    """
    global _loaded
    if not _loaded:
        n_layers = _quantize_model_on_gpu()
        _loaded = True
        print(f"Ready! {n_layers} layers, {torch.cuda.memory_allocated()/1e9:.1f} GB")

    messages = list(history) + [{"role": "user", "content": message}]
    chat_out = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
    # Depending on the tokenizer, the template returns an encoding or a bare tensor.
    if hasattr(chat_out, "input_ids"):
        input_ids = chat_out["input_ids"].to("cuda")
    else:
        input_ids = chat_out.to("cuda")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.3,
        streamer=streamer,
    )
    failures = []

    def _generate():
        # Runs in the worker thread. If generation dies, close the stream so
        # the consumer loop below does not wait forever.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    partial = ""
    for text in streamer:
        partial += text
        yield partial

    worker.join()
    if failures:
        raise failures[0]


demo = gr.ChatInterface(
    respond,
    title="🧊 Nemotron Cascade 30B-A3B — PolarQuant Q5+INT4",
    description="30B MoE (3B active) | 735 downloads | Most popular PolarQuant model | [Paper](https://arxiv.org/abs/2603.29078)",
    examples=["Explain quantum computing simply.", "Write a Python binary search.", "Compare TCP vs UDP."],
    type="messages",
)
demo.launch()