---
license: llama3
train: false
inference: false
pipeline_tag: text-generation
---

This is a calibrated, weight-only MXFP4-quantized version of Meta-Llama-3.1-8B-Instruct, as presented in our blog post.

## Usage

### Installation

```
pip install safetensors==0.6.0.dev0
```

The loading code below also imports `torch`, `transformers`, `accelerate`, and `huggingface_hub`, so these packages must be installed as well.

```Python
import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from os.path import join as pjoin
from safetensors import safe_open


@torch.compile(fullgraph=True)
def matmul_fp4(x, W_q, scales, group_size, fp4_values):
    """Dequantize packed FP4 weights and return ``x @ W.T``.

    Args:
        x: Input activations; the dequantized weight is cast to ``x.dtype``.
        W_q: Packed weights of shape ``(N, K // 2)``; each element holds two
            4-bit codes, low nibble first.
        scales: Group scales; must broadcast against a ``(-1, group_size)``
            view of the weights (presumably shape ``(N * K // group_size, 1)``
            -- confirm against the checkpoint).
        group_size: Number of consecutive weights that share one scale.
        fp4_values: 16-entry lookup table mapping a 4-bit code to its value.

    Returns:
        ``torch.matmul(x, W.T)``, where ``W`` is the ``(N, K)`` dequantized weight.
    """

    def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
        # Split each packed element into num_output_cols // n_cols codes of
        # W_nbits bits each; shift 0 extracts the lowest code first.
        n_rows, n_cols = W_q_packed.shape
        device = W_q_packed.device
        shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=W_q_packed.dtype) * W_nbits
        W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
        W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
        return W_q_unpacked

    N, K = W_q.shape[0], W_q.shape[1] * 2
    # Map every 4-bit code to its FP4 value via the lookup table.
    W_q = fp4_values[unpack_over_cols(W_q, W_nbits=4, num_output_cols=K, dtype=torch.int32)]
    # Apply the per-group scales in float32, then transpose for x @ W.T.
    W_r = (W_q.float().view([-1, group_size]) * scales.float()).reshape([N, K]).to(x.dtype).T
    return torch.matmul(x, W_r)


class AutoModelForCausalLMFP4:
    """Loads a causal LM whose linear layers are stored as packed FP4 weights."""

    @classmethod
    def from_pretrained(
        cls,
        save_dir_or_hub,
        torch_dtype=torch.bfloat16,
        cache_dir=None,
        device_map="cuda:0",
        *args,
        **kwargs,
    ):
        """Build the model from a local directory or Hub repo id and patch it for FP4.

        Args:
            save_dir_or_hub: Local directory, or a Hugging Face Hub repo id to download.
            torch_dtype: dtype applied to every floating-point checkpoint tensor.
            cache_dir: Cache directory forwarded to ``snapshot_download``.
            device_map: A single device, passed straight to ``Tensor.to``. It
                is not a per-module device map.
            *args, **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            The patched ``transformers`` model.
        """
        # Use a local directory as-is; otherwise download the repo snapshot.
        if os.path.exists(save_dir_or_hub):
            save_dir = save_dir_or_hub
        else:
            save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)

        # Build the model skeleton from its config without materializing weights.
        config = AutoConfig.from_pretrained(pjoin(save_dir, "config.json"))
        config.torch_dtype = str(torch_dtype).split('.')[-1]  # e.g. torch.bfloat16 -> "bfloat16"
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        # Load every tensor. Only floating-point tensors are cast, so packed
        # integer weights keep their original dtype.
        state_dict = {}
        with safe_open(pjoin(save_dir, "model.safetensors"), framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                dtype = torch_dtype if tensor.is_floating_point() else tensor.dtype
                state_dict[key] = tensor.to(device=device_map, dtype=dtype, non_blocking=True)

        cls.patch_model_for_fp4_inference(model=model, torch_dtype=torch_dtype, device=device_map, state_dict=state_dict)
        return model

    @classmethod
    def patch_model_for_fp4_inference(cls, model, torch_dtype, device, state_dict):
        """Load ``state_dict`` into ``model`` and route quantized linear layers through ``matmul_fp4``.

        Args:
            model: Model skeleton, presumably created under ``init_empty_weights``.
            torch_dtype: dtype of the FP4 lookup table.
            device: Device for the lookup table and the final ``model.to`` call.
            state_dict: Checkpoint tensors. A linear layer is treated as
                quantized when ``"<layer name>.W_q"`` is present.
        """
        # FP4 lookup table: codes 0-7 are the non-negative values and codes
        # 8-15 are their negations (sign in the high bit).
        model.fp4_values = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
            dtype=torch_dtype,
            device=device,
        )

        def patch_linearlayers(module, fct):
            # Recursively replace every nn.Linear child with fct(layer, name).
            for name, layer in module.named_children():
                if isinstance(layer, torch.nn.Linear):
                    setattr(module, name, fct(layer, name))
                else:
                    patch_linearlayers(layer, fct)

        def patch_enable_fp4(layer, arg):
            # lm_head, and any linear layer without packed weights, keeps its
            # dense weight and the stock nn.Linear forward.
            if 'lm_head' in layer.name or (layer.name + '.W_q') not in state_dict:
                return layer

            # Swap the dense weight for the quantization parameters. Keys
            # missing from the checkpoint are stored as None.
            if hasattr(layer, 'weight'):
                del layer.weight
            for key in ['W_q', 'scales', 'shift', 'post_scale', 'meta']:
                param_tag, param = layer.name + '.' + key, None
                if param_tag in state_dict:
                    param = state_dict[param_tag].tolist() if key in ["meta"] else state_dict[param_tag]
                setattr(layer, key, param)

            def forward(self, x):
                # meta[-1] is passed as the quantization group size.
                out = matmul_fp4(x, self.W_q, self.scales, self.meta[-1], model.fp4_values)
                # Optional trained post-scale and shift, then the layer's own bias.
                if self.post_scale is not None:
                    out *= self.post_scale
                if self.shift is not None:
                    out += self.shift
                if self.bias is not None:
                    out += self.bias
                return out

            layer.forward = lambda x: forward(layer, x)
            return layer

        # Assign the tensors that match existing parameters (embeddings,
        # norms, lm_head, ...). strict=False tolerates the FP4-only keys and
        # the missing dense weights of quantized layers, which
        # patch_enable_fp4 handles. Other loading errors (e.g. size
        # mismatches) still raise instead of being silently swallowed.
        model.load_state_dict(state_dict, strict=False, assign=True)

        # Give every module its qualified name so patch_enable_fp4 can find its keys.
        for name, module in model.named_modules():
            module.name = name

        patch_linearlayers(model, patch_enable_fp4)
        model.to(device)  # nn.Module.to moves parameters/buffers in place
```

### Inference

```Python
model_id = "mobiuslabsgmbh/Llama-3.1-8B-Instruct_mxfp4_weights_calib_demo"
model = AutoModelForCausalLMFP4.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Check the trained params
# print(model.model.layers[-1].self_attn.v_proj.shift)
# tensor([ 0.0034, -0.0036, 0.0054, ..., 0.0036, -0.0076, -0.0068],
#        device='cuda:0', dtype=torch.bfloat16)
# print(model.model.layers[-1].self_attn.v_proj.post_scale)
# tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16)

outputs = model.generate(
    tokenizer.apply_chat_template(
        [{"role": "user", "content": "Solve the following equation: x^2 + 1 = -1"}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device),
    max_new_tokens=256,
)

print(tokenizer.decode(outputs[0]))
```