{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "L4"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# \ud83e\uddca Gemma 4 26B-A4B-it \u2014 PolarQuant Q5+INT4 (Vision)\n",
    "\n",
    "**25.2B MoE (3.8B active) + Vision** on consumer GPUs.\n",
    "\n",
    "| Component | Method | Effect |\n",
    "|---|---|---|\n",
    "| **Text weights** | PolarQuant Q5 \u2192 torchao INT4 | BF16 ~50 GB \u2192 ~13 GB |\n",
    "| **Vision encoder** | BF16 (preserved) | Full image quality |\n",
    "| **MoE routing** | BF16 (preserved) | Exact expert selection |\n",
    "\n",
    "| GPU | VRAM | Status |\n",
    "|---|---|---|\n",
    "| T4 | 16 GB | \u2705 (~13 GB) |\n",
    "| L4 / RTX 4090 | 24 GB | \u2705 Comfortable |\n",
    "| A100 | 40-80 GB | \u2705 |\n",
    "\n",
    "Run the cells in order. The Gradio demo and the Hub upload are optional final steps; the upload reads a write token from the `HF_TOKEN` environment variable (or prompts for one).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup\n",
    "\n",
    "Installs `transformers` from its main branch plus the quantization and demo dependencies, imports everything used below and reports the GPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "%pip install -q --force-reinstall git+https://github.com/huggingface/transformers.git\n",
    "%pip install -q accelerate safetensors sentencepiece scipy torchao gradio torchvision pillow\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import gc, json, math, os, time\n",
    "from getpass import getpass\n",
    "from threading import Thread\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import gradio as gr\n",
    "import torchao.quantization.utils as _tao_utils\n",
    "from huggingface_hub import HfApi, login\n",
    "from scipy.stats import norm as sp_norm\n",
    "from torchao.quantization import Int4WeightOnlyConfig, quantize_\n",
    "from transformers import AutoModelForMultimodalLM, AutoProcessor, TextIteratorStreamer\n",
    "\n",
    "DEVICE = 'cuda'\n",
    "MODEL = 'google/gemma-4-26B-A4B-it'\n",
    "BS = 128        # PolarQuant block size: Hadamard size and per-block norm granularity\n",
    "HEAD_DIM = 256  # only recorded in polar_config.json\n",
    "\n",
    "gpu_name = torch.cuda.get_device_name(0)\n",
    "gpu_vram = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
    "print(f'GPU: {gpu_name} ({gpu_vram:.0f} GB)')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. PolarQuant core\n",
    "\n",
    "Each weight row is cut into blocks of `BS` = 128 values. Every block goes through four steps:\n",
    "\n",
    "1. Rotate it by an orthonormal Hadamard matrix.\n",
    "2. Rescale it so its L2 norm is $\\sqrt{128}$.\n",
    "3. Snap each entry to the nearest of 32 Lloyd\u2013Max centroids of $\\mathcal{N}(0,1)$ (Q5).\n",
    "4. Undo the scaling and the rotation.\n",
    "\n",
    "The resulting BF16 weight is what torchao then packs to INT4. Routers, norms and the vision stack are excluded.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# PolarQuant core: Lloyd-Max codebooks, Hadamard rotation, layer filters and the Q5 round-trip.\n",
    "_C = {}  # bits -> cached centroid tensor\n",
    "\n",
    "def get_centroids(bits):\n",
    "    \"\"\"Return the 2**bits Lloyd-Max reconstruction levels for a standard normal (cached).\n",
    "\n",
    "    Boundaries start uniform on [-4, 4]; the outer two stay pinned at +/-4 while 100 Lloyd-Max\n",
    "    iterations alternate between conditional-mean centroids and midpoint boundaries.\n",
    "    \"\"\"\n",
    "    if bits in _C: return _C[bits]\n",
    "    n = 1 << bits\n",
    "    bd = torch.linspace(-4.0, 4.0, n + 1)\n",
    "    ct = torch.zeros(n)\n",
    "    for _ in range(100):\n",
    "        for i in range(n):\n",
    "            a, b = bd[i].item(), bd[i+1].item()\n",
    "            pa, pb = sp_norm.cdf(a), sp_norm.cdf(b)\n",
    "            # E[X | a < X < b] for X ~ N(0, 1); midpoint fallback for an (almost) empty cell.\n",
    "            ct[i] = (sp_norm.pdf(a) - sp_norm.pdf(b)) / (pb - pa) if pb - pa > 1e-12 else (a + b) / 2\n",
    "        for i in range(1, n): bd[i] = (ct[i-1] + ct[i]) / 2\n",
    "    _C[bits] = ct\n",
    "    return ct\n",
    "\n",
    "for b in [2, 3, 4, 5, 6]: get_centroids(b)\n",
    "\n",
    "def _build_H(n):\n",
    "    \"\"\"Sylvester Hadamard matrix of size n (a power of two), scaled to be orthonormal.\n",
    "\n",
    "    It is also symmetric, so H @ H = I: multiplying by H a second time undoes the rotation.\n",
    "    \"\"\"\n",
    "    if n == 1: return torch.tensor([[1.0]])\n",
    "    h = _build_H(n // 2)\n",
    "    return torch.cat([torch.cat([h, h], 1), torch.cat([h, -h], 1)], 0) / math.sqrt(2)\n",
    "\n",
    "H_W = _build_H(BS)\n",
    "\n",
    "VISION_KEYS = ('vision_tower', 'vision_model', 'multi_modal_projector')\n",
    "\n",
    "def is_vision(name):\n",
    "    \"\"\"True for names inside the vision encoder or multimodal projector (kept in BF16).\"\"\"\n",
    "    return any(k in name for k in VISION_KEYS)\n",
    "\n",
    "def is_router(name):\n",
    "    \"\"\"True for MoE router modules/params (kept in BF16 so expert selection is unchanged).\n",
    "\n",
    "    Matches module names (``...gate``) as well as parameter names (``...gate.weight``):\n",
    "    ``named_modules()`` never yields a name ending in ``.weight``.\n",
    "    \"\"\"\n",
    "    return name.endswith('.gate') or name.endswith('.gate.weight') or 'router' in name\n",
    "\n",
    "def should_quantize(name, param):\n",
    "    \"\"\"Return True if ``param`` (belonging to ``name``) should go through PQ5 + INT4.\"\"\"\n",
    "    if param.ndim < 2 or param.numel() < 256: return False\n",
    "    if any(k in name for k in ['norm', 'layernorm', 'rmsnorm']): return False\n",
    "    if any(k in name for k in ['A_log', '.D', 'dt_bias', 'conv1d']): return False\n",
    "    if is_router(name): return False  # CRITICAL for MoE: routers stay BF16\n",
    "    if is_vision(name): return False  # vision encoder stays BF16\n",
    "    return True\n",
    "\n",
    "def polarquant_roundtrip(weight, H, centroids, bs=BS, row_chunk=64, q_chunk=256):\n",
    "    \"\"\"PolarQuant-quantize a 2-D weight and return its dequantized BF16 reconstruction.\n",
    "\n",
    "    Rows are zero-padded to a multiple of ``bs`` and split into ``bs``-sized blocks. Each block is\n",
    "    rotated by ``H``, scaled to L2 norm sqrt(bs), snapped to the nearest entry of ``centroids``,\n",
    "    then scaled back and rotated by ``H`` again (``H`` is its own inverse). Work runs in float32\n",
    "    on ``H.device``; ``row_chunk``/``q_chunk`` bound temporary memory.\n",
    "    Returns an (out_features, in_features) bfloat16 tensor on ``H.device``.\n",
    "    \"\"\"\n",
    "    w = weight.float().to(H.device)\n",
    "    out_f, in_f = w.shape\n",
    "    pad = (bs - in_f % bs) % bs\n",
    "    if pad > 0: w = F.pad(w, (0, pad))\n",
    "    nb = w.shape[1] // bs\n",
    "    w = w.reshape(out_f, nb, bs)\n",
    "    for i in range(0, out_f, row_chunk):\n",
    "        e = min(i + row_chunk, out_f)\n",
    "        w[i:e] = (w[i:e].reshape(-1, bs) @ H).reshape(e - i, nb, bs)\n",
    "    norms = w.norm(dim=2, keepdim=True).clamp(min=1e-10)\n",
    "    w.div_(norms).mul_(math.sqrt(bs))\n",
    "    # int8 codes: valid while len(centroids) <= 128.\n",
    "    codes = torch.empty(out_f, nb, bs, dtype=torch.int8, device=w.device)\n",
    "    for i in range(0, out_f, q_chunk):\n",
    "        e = min(i + q_chunk, out_f)\n",
    "        codes[i:e] = (w[i:e].unsqueeze(-1) - centroids.view(1, 1, 1, -1)).abs().argmin(-1).to(torch.int8)\n",
    "    del w\n",
    "    vals = torch.empty(out_f, nb, bs, dtype=torch.float32, device=codes.device)\n",
    "    for i in range(0, out_f, q_chunk):\n",
    "        e = min(i + q_chunk, out_f)\n",
    "        vals[i:e] = centroids[codes[i:e].long()] / math.sqrt(bs)\n",
    "    del codes\n",
    "    if vals.is_cuda: torch.cuda.empty_cache()\n",
    "    for i in range(0, out_f, row_chunk):\n",
    "        e = min(i + row_chunk, out_f)\n",
    "        vals[i:e] = (vals[i:e].reshape(-1, bs) @ H).reshape(e - i, nb, bs)\n",
    "    vals *= norms\n",
    "    return vals.reshape(out_f, -1)[:, :in_f].to(torch.bfloat16)\n",
    "\n",
    "# Lenient stand-in for torchao's guard_dtype_size: casts a dtype mismatch in place and still\n",
    "# rejects a size mismatch. It is installed only around the quantization loop (next cell).\n",
    "# The pristine function is captured once, so re-running this cell can never store the patch as the original.\n",
    "if not hasattr(_tao_utils, '_pq_orig_guard_dtype_size'):\n",
    "    _tao_utils._pq_orig_guard_dtype_size = _tao_utils.guard_dtype_size\n",
    "_orig = _tao_utils._pq_orig_guard_dtype_size\n",
    "\n",
    "def _patched(t, n, dtype=None, size=None):\n",
    "    if dtype is not None and t.dtype != dtype: t.data = t.data.to(dtype)\n",
    "    if size is not None and t.size() != size: raise ValueError(f'{size} vs {t.size()}')\n",
    "\n",
    "print('Core loaded.')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load and quantize\n",
    "\n",
    "The loader runs in three steps:\n",
    "\n",
    "1. Load the BF16 checkpoint on CPU.\n",
    "2. Stream each eligible `nn.Linear` through the GPU for PQ5 \u2192 INT4, keeping vision layers and routers in BF16.\n",
    "3. Move everything that is still on the CPU to the GPU.\n",
    "\n",
    "If torchao cannot pack a layer, that layer keeps its PQ5-dequantized BF16 weight and is reported as a fallback.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Streaming loader \u2014 INT4 text+MoE experts, BF16 vision+routers\n",
    "print('Step 1/3: Loading on CPU...')\n",
    "t_total = time.time()\n",
    "model = AutoModelForMultimodalLM.from_pretrained(MODEL, dtype=torch.bfloat16, device_map='cpu', attn_implementation='sdpa')\n",
    "processor = AutoProcessor.from_pretrained(MODEL)\n",
    "config = model.config\n",
    "text_config = config.text_config if hasattr(config, 'text_config') else config\n",
    "num_layers = text_config.num_hidden_layers\n",
    "num_kv_heads = text_config.num_key_value_heads\n",
    "print(f' {sum(p.numel() for p in model.parameters())/1e9:.1f}B params, {num_layers} layers')\n",
    "\n",
    "print('\\nStep 2/3: PQ5+INT4 streaming (vision+routers stay BF16)...')\n",
    "H_dev = H_W.to(DEVICE); ct5 = get_centroids(5).to(DEVICE)\n",
    "int4_config = Int4WeightOnlyConfig(group_size=128)\n",
    "n_quant, n_fallback, n_skip_vision, n_skip_router, n_skip_other = 0, 0, 0, 0, 0\n",
    "first_error = None\n",
    "t0 = time.time()\n",
    "\n",
    "_tao_utils.guard_dtype_size = _patched\n",
    "try:\n",
    "    for name, child in list(model.named_modules()):\n",
    "        if not isinstance(child, nn.Linear) or child.weight.device.type == 'meta':\n",
    "            continue\n",
    "        if is_vision(name):\n",
    "            n_skip_vision += 1; continue\n",
    "        if is_router(name):\n",
    "            n_skip_router += 1; continue\n",
    "        if not should_quantize(name, child.weight):\n",
    "            n_skip_other += 1; continue\n",
    "\n",
    "        out_f, in_f = child.weight.shape\n",
    "        bf16_w = polarquant_roundtrip(child.weight.data, H_dev, ct5)\n",
    "        torch.cuda.empty_cache()\n",
    "        try:\n",
    "            # Build the wrapper on 'meta' (no allocation), then attach the real BF16 weight.\n",
    "            with torch.device('meta'):\n",
    "                dummy = nn.Sequential(nn.Linear(in_f, out_f, bias=False))\n",
    "            dummy[0].weight = nn.Parameter(bf16_w)\n",
    "            quantize_(dummy, int4_config)\n",
    "            child.weight = dummy[0].weight\n",
    "            del dummy\n",
    "            n_quant += 1\n",
    "        except Exception as e:\n",
    "            # torchao could not pack this layer: keep the PQ5-dequantized BF16 weight instead.\n",
    "            child.weight.data = bf16_w.cpu()\n",
    "            n_fallback += 1\n",
    "            if first_error is None:\n",
    "                first_error = f'{name}: {type(e).__name__}: {e}'\n",
    "        del bf16_w; torch.cuda.empty_cache()\n",
    "        n_done = n_quant + n_fallback\n",
    "        if n_done % 200 == 0:\n",
    "            print(f' {n_done} layers ({torch.cuda.memory_allocated()/1e9:.1f} GB VRAM)...')\n",
    "finally:\n",
    "    # Always restore torchao, even if the loop fails part-way.\n",
    "    _tao_utils.guard_dtype_size = _orig\n",
    "\n",
    "print(f' INT4: {n_quant} | PQ5 BF16 fallback: {n_fallback} | Vision BF16: {n_skip_vision} | Router BF16: {n_skip_router} | Other BF16: {n_skip_other}')\n",
    "if first_error:\n",
    "    print(f' First INT4 fallback: {first_error}')\n",
    "print(f' Done in {time.time()-t0:.0f}s')\n",
    "\n",
    "print('\\nStep 3/3: Moving remaining to GPU...')\n",
    "for _, p in model.named_parameters():\n",
    "    if p.device.type == 'cpu': p.data = p.data.to(DEVICE)\n",
    "for _, b in model.named_buffers():\n",
    "    if b.device.type == 'cpu': b.data = b.data.to(DEVICE)\n",
    "gc.collect(); torch.cuda.empty_cache()\n",
    "vram = torch.cuda.memory_allocated()/1e9\n",
    "print(f'\\n\u2705 Ready! VRAM: {vram:.1f} GB | Time: {time.time()-t_total:.0f}s')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Smoke tests (text + vision)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def generate_reply(msgs, max_new_tokens=100):\n",
    "    \"\"\"Greedy-decode a reply to chat ``msgs`` and return only the newly generated text.\"\"\"\n",
    "    inputs = processor.apply_chat_template(msgs, tokenize=True, return_dict=True, return_tensors='pt', add_generation_prompt=True).to(DEVICE)\n",
    "    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)\n",
    "    return processor.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
    "\n",
    "print('--- Text test ---')\n",
    "msgs = [{'role': 'user', 'content': [{'type': 'text', 'text': 'What is 2+2?'}]}]\n",
    "print(generate_reply(msgs, max_new_tokens=50))\n",
    "\n",
    "print('\\n--- Vision test ---')\n",
    "msgs_img = [{'role': 'user', 'content': [\n",
    "    {'type': 'image', 'url': 'https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png'},\n",
    "    {'type': 'text', 'text': 'What is shown in this image? Answer briefly.'},\n",
    "]}]\n",
    "print(generate_reply(msgs_img, max_new_tokens=100))\n",
    "print(f'\\nVRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Decode speed\n",
    "\n",
    "This cell runs greedy decoding for up to 100 new tokens and averages the rate over 3 timed runs, after one warm-up run. It sets `avg_speed`, which the upload cell writes into the model card.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Speed benchmark\n",
    "msgs_bench = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Write a detailed essay about artificial intelligence.'}]}]\n",
    "inputs_bench = processor.apply_chat_template(msgs_bench, tokenize=True, return_dict=True, return_tensors='pt', add_generation_prompt=True).to(DEVICE)\n",
    "prompt_len = inputs_bench['input_ids'].shape[1]\n",
    "\n",
    "# Warmup\n",
    "with torch.no_grad():\n",
    "    model.generate(**inputs_bench, max_new_tokens=10, do_sample=False)\n",
    "\n",
    "speeds = []\n",
    "for run in range(3):\n",
    "    torch.cuda.synchronize()\n",
    "    t0 = time.perf_counter()\n",
    "    with torch.no_grad():\n",
    "        out_bench = model.generate(**inputs_bench, max_new_tokens=100, do_sample=False)\n",
    "    torch.cuda.synchronize()\n",
    "    elapsed = time.perf_counter() - t0\n",
    "    n_tok = out_bench.shape[1] - prompt_len\n",
    "    tps = n_tok / elapsed\n",
    "    speeds.append(tps)\n",
    "    print(f'Run {run+1}: {tps:.1f} tok/s ({n_tok} tokens)')\n",
    "\n",
    "avg_speed = sum(speeds)/len(speeds)\n",
    "print(f'\\nAverage: {avg_speed:.1f} tok/s')\n",
    "print(f'VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Interactive chat (Gradio)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Gradio Multimodal Chat\n",
    "def _to_parts(content):\n",
    "    \"\"\"Convert one Gradio history ``content`` value into the processor's list-of-parts format.\n",
    "\n",
    "    Handles plain strings, file tuples ``(path, ...)``, file dicts (``{'path': ...}`` or\n",
    "    ``{'file': {'path': ...}}``), single part dicts and lists that are already parts.\n",
    "    Files are treated as images, exactly like new uploads in ``chat_fn``.\n",
    "    \"\"\"\n",
    "    if isinstance(content, str):\n",
    "        return [{'type': 'text', 'text': content}]\n",
    "    if isinstance(content, dict):\n",
    "        if 'type' in content:\n",
    "            return [content]\n",
    "        f = content.get('file', content)\n",
    "        path = (f.get('path') or f.get('url')) if isinstance(f, dict) else f\n",
    "        return [{'type': 'image', 'url': path}] if path else []\n",
    "    if isinstance(content, tuple) or (isinstance(content, list) and content and isinstance(content[0], str)):\n",
    "        return [{'type': 'image', 'url': content[0]}] if content else []\n",
    "    if isinstance(content, list):\n",
    "        return list(content)\n",
    "    return [{'type': 'text', 'text': str(content)}]\n",
    "\n",
    "def _build_messages(message, history):\n",
    "    \"\"\"Turn Gradio ``type='messages'`` history plus the new input into processor chat messages.\n",
    "\n",
    "    Consecutive turns with the same role are merged (Gradio may store an upload and its caption\n",
    "    as separate user turns), so roles alternate.\n",
    "    \"\"\"\n",
    "    messages = []\n",
    "    def add(role, parts):\n",
    "        if messages and messages[-1]['role'] == role:\n",
    "            messages[-1]['content'].extend(parts)\n",
    "        else:\n",
    "            messages.append({'role': role, 'content': list(parts)})\n",
    "    for turn in history:\n",
    "        add(turn['role'], _to_parts(turn['content']))\n",
    "    content = []\n",
    "    if isinstance(message, dict):\n",
    "        for f in message.get('files') or []:\n",
    "            content.append({'type': 'image', 'url': f.get('path') if isinstance(f, dict) else f})\n",
    "        if message.get('text'): content.append({'type': 'text', 'text': message['text']})\n",
    "    else:\n",
    "        content.append({'type': 'text', 'text': str(message)})\n",
    "    add('user', content)\n",
    "    return messages\n",
    "\n",
    "@torch.no_grad()\n",
    "def chat_fn(message, history):\n",
    "    \"\"\"Stream a sampled reply; generation runs in a background thread feeding the streamer.\"\"\"\n",
    "    messages = _build_messages(message, history)\n",
    "    inputs = processor.apply_chat_template(messages, tokenize=True, return_dict=True, return_tensors='pt', add_generation_prompt=True).to(DEVICE)\n",
    "    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
    "    gen_kwargs = dict(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.3, streamer=streamer)\n",
    "\n",
    "    def _run():\n",
    "        try:\n",
    "            model.generate(**gen_kwargs)\n",
    "        except Exception:\n",
    "            streamer.end()  # unblock the consumer loop below instead of hanging the UI\n",
    "            raise\n",
    "\n",
    "    Thread(target=_run, daemon=True).start()\n",
    "    partial = ''\n",
    "    for text in streamer:\n",
    "        partial += text\n",
    "        yield partial\n",
    "\n",
    "demo = gr.ChatInterface(chat_fn,\n",
    "    title='\ud83e\uddca Gemma 4 26B-A4B Vision \u2014 PolarQuant Q5+INT4',\n",
    "    description=f'25.2B MoE (3.8B active) + Vision | VRAM: {torch.cuda.memory_allocated()/1e9:.0f} GB',\n",
    "    examples=['Explain quantum computing.', 'Write Python binary search.'],\n",
    "    multimodal=True, type='messages')\n",
    "demo.launch(share=True, quiet=True)\n",
    "print('\\n\ud83d\udd17 Share the link!')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Save and upload\n",
    "\n",
    "This cell needs a Hub write token in `HF_TOKEN`; if the variable is unset, it prompts for one. Run the speed benchmark first, and edit `REPO` before running.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Save + Upload\n",
    "assert 'avg_speed' in globals(), 'Run the speed benchmark cell first (avg_speed goes into the model card).'\n",
    "login(token=os.environ.get('HF_TOKEN') or getpass('Hugging Face write token: '))\n",
    "api = HfApi()\n",
    "\n",
    "REPO = 'caiovicentino1/Gemma-4-26B-A4B-it-PolarQuant-Q5-Vision'\n",
    "SAVE_PATH = '/content/model_26b_int4.pt'\n",
    "SAVE_DIR = '/content/upload_26b'\n",
    "\n",
    "print('Saving...')\n",
    "torch.save(model.state_dict(), SAVE_PATH)\n",
    "sz = os.path.getsize(SAVE_PATH)/1e9\n",
    "print(f'Saved: {sz:.1f} GB')\n",
    "\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "model.config.save_pretrained(SAVE_DIR)\n",
    "processor.save_pretrained(SAVE_DIR)\n",
    "\n",
    "def _first_attr(cfg, names, default):\n",
    "    \"\"\"Return the first non-None attribute of ``cfg`` among ``names``, else ``default``.\"\"\"\n",
    "    for attr in names:\n",
    "        value = getattr(cfg, attr, None)\n",
    "        if value is not None:\n",
    "            return value\n",
    "    return default\n",
    "\n",
    "# MoE sizes: taken from the text config when exposed (attribute names vary), else 128 experts / top-8.\n",
    "num_experts = _first_attr(text_config, ('num_experts', 'num_local_experts'), 128)\n",
    "top_k = _first_attr(text_config, ('top_k_experts', 'num_experts_per_tok'), 8)\n",
    "vram = torch.cuda.memory_allocated()/1e9\n",
    "\n",
    "polar_config = {\n",
    "    'quantization_method': 'PolarQuant', 'version': 'v5_vision', 'text_weight_bits': 4,\n",
    "    'vision_weight_bits': 16,\n",
    "    'kv_cache_bits': 3,  # NOTE(review): this notebook never quantizes the KV cache; presumably a runtime setting \u2014 confirm\n",
    "    'block_size': BS, 'head_dim': HEAD_DIM,\n",
    "    'base_model': MODEL, 'multimodal': True, 'pipeline_tag': 'image-text-to-text',\n",
    "    'num_layers': num_layers, 'num_kv_heads': num_kv_heads, 'moe': True,\n",
    "    'num_experts': num_experts, 'top_k': top_k,\n",
    "    'vram_gb': round(vram, 1), 'tok_s': round(avg_speed, 1),\n",
    "}\n",
    "with open(os.path.join(SAVE_DIR, 'polar_config.json'), 'w') as fh:\n",
    "    json.dump(polar_config, fh, indent=2)\n",
    "\n",
    "print('Uploading...')\n",
    "api.create_repo(REPO, exist_ok=True)\n",
    "api.upload_file(path_or_fileobj=SAVE_PATH, path_in_repo='model_int4.pt', repo_id=REPO, repo_type='model')\n",
    "api.upload_folder(folder_path=SAVE_DIR, repo_id=REPO, repo_type='model')\n",
    "\n",
    "card = f\"\"\"---\n",
    "license: apache-2.0\n",
    "tags:\n",
    "- polarquant\n",
    "- gemma4\n",
    "- moe\n",
    "- vision\n",
    "- quantized\n",
    "base_model: {MODEL}\n",
    "pipeline_tag: image-text-to-text\n",
    "---\n",
    "\n",
    "# \ud83e\uddca Gemma-4-26B-A4B-it-PolarQuant-Q5-Vision\n",
    "\n",
    "**25.2B MoE (3.8B active) + Vision** on consumer GPUs.\n",
    "\n",
    "| Metric | Value |\n",
    "|---|---|\n",
    "| **VRAM** | {vram:.1f} GB |\n",
    "| **Speed** | {avg_speed:.1f} tok/s |\n",
    "| **Architecture** | {num_layers} layers, {num_experts} experts (top-{top_k}) |\n",
    "| **Vision** | BF16 (full quality) |\n",
    "| **Routers** | BF16 (exact expert selection) |\n",
    "\n",
    "## GPU Support\n",
    "\n",
    "| GPU | Fits? |\n",
    "|---|---|\n",
    "| T4 16GB | \u2705 |\n",
    "| RTX 4090 24GB | \u2705 |\n",
    "| A100 40GB | \u2705 |\n",
    "\n",
    "\ud83d\udcc4 [Paper](https://arxiv.org/abs/2603.29078) \u00b7 \ud83d\udcbb [GitHub](https://github.com/caiovicentino/polarengine-vllm) \u00b7 \ud83d\udce6 [pip install polarquant](https://pypi.org/project/polarquant/)\n",
    "\"\"\"\n",
    "api.upload_file(path_or_fileobj=card.encode(), path_in_repo='README.md', repo_id=REPO, repo_type='model')\n",
    "\n",
    "COLLECTIONS = (\n",
    "    'caiovicentino1/polarquant-models-69cbc96292c5174df2088b08',\n",
    "    'caiovicentino1/polarquant-gemma-models-69ceedd4896e4cd587972c0c',\n",
    ")\n",
    "for slug in COLLECTIONS:\n",
    "    try:\n",
    "        api.add_collection_item(collection_slug=slug, item_id=REPO, item_type='model', exists_ok=True)\n",
    "    except Exception as e:  # best effort: a collection problem must not fail the upload\n",
    "        print(f'Skipped collection {slug}: {type(e).__name__}: {e}')\n",
    "print(f'\\n\u2705 https://huggingface.co/{REPO}')\n"
   ]
  }
 ]
}