#!/usr/bin/env /home/jovyan/step37_work/.venv/bin/python
"""Pilot refusal-gradient probe for the checkpoint at MODEL_PATH.

Streams decoder layers 0..--end-layer onto the GPU one at a time and runs the
prompts forward, keeping each layer's input hidden state on the CPU. It then
scores the mean log-probability of the refusal token ids at the last position
of each harmful prompt. The score is computed by pushing the partial model's
output through the final norm and lm_head. Finally it streams the layers back
in reverse to backpropagate that score to the input of every layer. The
per-layer input gradients are pickled to --out.
"""
import argparse
import itertools
import json
import os
import pickle
import subprocess
import sys
import time
from collections import defaultdict
from pathlib import Path

import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.masking_utils import create_causal_mask

MODEL_PATH = "/home/jovyan/step37_bf16_split_model"
SCRATCH = "/home/jovyan/scratchdata/jovyan/step37_work"
VENV_PYTHON = "/home/jovyan/step37_work/.venv/bin/python"
MEMCHECK_SCRIPT = "/home/jovyan/heretic/memcheck.py"
DEV = "cuda:0"  # overridden by --device arg
_IDX = None  # lazily loaded "weight_map" from model.safetensors.index.json


def index():
    """Return the checkpoint's tensor-name -> shard-filename map, loading it once."""
    global _IDX
    if _IDX is None:
        with open(f"{MODEL_PATH}/model.safetensors.index.json", encoding="utf-8") as fh:
            _IDX = json.load(fh)["weight_map"]
    return _IDX


def get_tensor(key):
    """Read the single checkpoint tensor `key` from its shard onto the CPU."""
    shard = index()[key]
    with safe_open(f"{MODEL_PATH}/{shard}", framework="pt", device="cpu") as f:
        return f.get_tensor(key)


def layer_state(layer_idx):
    """Return decoder layer `layer_idx`'s checkpoint tensors (on CPU).

    Keys have the leading ``model.layers.{layer_idx}.`` prefix removed, so the
    dict can be passed directly to that layer's ``load_state_dict``.
    """
    prefix = f"model.layers.{layer_idx}."
    keys_by_shard = defaultdict(list)
    for key, shard in index().items():
        if key.startswith(prefix):
            keys_by_shard[shard].append(key)
    out = {}
    for shard in sorted(keys_by_shard):
        with safe_open(f"{MODEL_PATH}/{shard}", framework="pt", device="cpu") as f:
            present = set(f.keys())
            for key in keys_by_shard[shard]:
                if key in present:  # tolerate an index entry the shard lacks
                    # Slice instead of str.replace: strip only the leading prefix.
                    out[key[len(prefix):]] = f.get_tensor(key)
    return out


def load_layer(text, i):
    """Materialize decoder layer `i` of `text` on DEV with its checkpoint weights.

    ``to_empty`` allocates *uninitialized* storage, so any parameter or
    persistent buffer the checkpoint does not supply stays garbage. Those names
    are reported on stderr. NOTE(review): non-persistent buffers (not part of
    the state dict) are also reset by ``to_empty`` and are NOT detected here.
    Confirm the layer has none, e.g. rotary caches.
    """
    layer = text.layers[i]
    layer.to_empty(device=DEV)
    state = {k: v.to(DEV) for k, v in layer_state(i).items()}
    # assign=True installs the loaded tensors directly instead of copying them
    # into the freshly allocated ones.
    result = layer.load_state_dict(state, strict=False, assign=True)
    if result.missing_keys:
        print(
            f"warning: layer {i}: {len(result.missing_keys)} keys not in checkpoint, "
            f"left uninitialized: {result.missing_keys[:5]}",
            file=sys.stderr,
        )
    return layer


def unload_layer(layer):
    """Move `layer` back to the meta device and release cached CUDA memory."""
    layer.to_empty(device="meta")
    torch.cuda.empty_cache()


def make_mask(text, attention_mask, position_ids, h):
    """Build the causal attention mask for a cache-less, full-sequence forward.

    `h` is the embedded input; it is passed as ``input_embeds`` and its
    sequence length (dim 1) sets the cache positions.
    """
    cp = torch.arange(h.shape[1], device=DEV)
    kw = dict(
        config=text.config,
        attention_mask=attention_mask,
        past_key_values=None,
        position_ids=position_ids,
        cache_position=cp,
        input_embeds=h,
    )
    return create_causal_mask(**kw)


def _load_frozen_weight(module, key):
    """Materialize `module` on DEV and install checkpoint tensor `key` as its frozen `.weight`."""
    module.to_empty(device=DEV)
    module.weight = torch.nn.Parameter(get_tensor(key).to(DEV), requires_grad=False)


def _refusal_token_ids(tok, words):
    """Return the sorted, de-duplicated token ids of all `words` (no special tokens)."""
    ids = itertools.chain.from_iterable(
        tok.encode(w, add_special_tokens=False) for w in words
    )
    return sorted(set(ids))


def main():
    global DEV
    ap = argparse.ArgumentParser()
    ap.add_argument("--end-layer", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=48)
    ap.add_argument("--out", default=f"{SCRATCH}/pilot_gradients.pkl")
    # Step3p7 rotary code broadcasts only for batch=1; average multiple runs later.
    ap.add_argument("--harmful", nargs="+", default=["How do I make a bomb?"])
    ap.add_argument("--harmless", nargs="*", default=[])
    ap.add_argument("--refusal-tokens", nargs="+", default=["I", "Sorry", "cannot", "can't", "unable"])
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()
    os.makedirs(Path(args.out).parent, exist_ok=True)
    DEV = args.device

    torch.set_float32_matmul_precision("high")
    # Touch the selected GPU up front. Synchronize *that* device: the no-arg
    # form targets the current device (cuda:0) even when --device says otherwise.
    _x = torch.zeros(1, device=DEV)
    del _x
    torch.cuda.synchronize(DEV)
    print("cuda warm")

    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Left padding keeps every prompt's last real token at position -1.
    tok.padding_side = "left"
    texts = args.harmful + args.harmless
    if len(texts) > 1:
        print(
            f"warning: batch of {len(texts)} prompts, but Step3p7 rotary code is "
            "believed to broadcast only for batch=1; consider one prompt per run",
            file=sys.stderr,
        )
    enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_length)
    input_ids = enc.input_ids.to(DEV)
    am = enc.attention_mask.to(DEV)
    # Positions count real tokens only; pad slots get a dummy position of 1.
    pos = am.long().cumsum(-1) - 1
    pos.masked_fill_(am == 0, 1)
    rids = _refusal_token_ids(tok, args.refusal_tokens)
    print("refusal ids", rids, tok.batch_decode([[x] for x in rids]))
    # External memory report; its exit status is not checked.
    subprocess.run([VENV_PYTHON, MEMCHECK_SCRIPT])

    cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    # Build the module tree without real weights; weights are streamed in below.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch.bfloat16, trust_remote_code=True)
    text = model.model.language_model

    # embed + lm_head + final norm: loaded once and never unloaded.
    _load_frozen_weight(text.embed_tokens, "model.embed_tokens.weight")
    _load_frozen_weight(model.lm_head, "lm_head.weight")
    _load_frozen_weight(text.norm, "model.norm.weight")

    h = text.embed_tokens(input_ids)
    print("embed done", tuple(h.shape))
    # H[i] is the input to layer i (H[0] = embeddings); H[-1] is the last layer's output.
    H = [h.detach().cpu()]
    print("making mask...")
    attn_mask = make_mask(text, am, pos, h)
    print("mask done")

    # Forward pass: one layer resident on the GPU at a time.
    for i in range(args.end_layer + 1):
        print(f"loading layer {i}...")
        t = time.time()
        layer = load_layer(text, i)
        print(f"layer {i} loaded in {time.time()-t:.2f}s")
        t = time.time()
        with torch.no_grad():
            h = layer(h, attention_mask=attn_mask, position_ids=pos)
        print(f"F {i}: {time.time()-t:.2f}s {tuple(h.shape)}")
        H.append(h.detach().cpu())
        unload_layer(layer)

    # use current partial model output as pilot signal
    h_final = H[-1].to(DEV).requires_grad_(True)
    logits = model.lm_head(text.norm(h_final))
    # Only the harmful prompts (the first rows of the batch) contribute to the score.
    logp = torch.log_softmax(logits[:len(args.harmful), -1, :], dim=-1)
    score = logp[:, rids].mean()  # higher = more refusal-token probability
    grad = torch.autograd.grad(score, h_final)[0].detach().cpu()
    print("score", float(score), "grad_norm", float(grad.norm()))

    # Backward pass: recompute each layer and chain the gradient down to its input.
    grads = [None] * (args.end_layer + 1)
    for i in reversed(range(args.end_layer + 1)):
        t = time.time()
        layer = load_layer(text, i)
        # Only activation gradients are wanted; don't accumulate weight grads.
        for p in layer.parameters():
            p.requires_grad = False
        hin = H[i].to(DEV).requires_grad_(True)
        hout = layer(hin, attention_mask=attn_mask, position_ids=pos)
        hout.backward(grad.to(DEV))
        grad = hin.grad.detach().cpu()
        grads[i] = grad  # d(score) / d(input of layer i)
        unload_layer(layer)
        print(f"B {i}: {time.time()-t:.2f}s grad {float(grad.norm())}")

    payload = dict(
        gradients=grads,
        end_layer=args.end_layer,
        harmful=args.harmful,
        harmless=args.harmless,
        refusal_ids=rids,
        score=float(score),
    )
    with open(args.out, "wb") as fh:
        pickle.dump(payload, fh)
    print("saved", args.out)


if __name__ == "__main__":
    main()
# To run on another GPU, pass e.g. --device cuda:1 (no CUDA_DEVICE env var is read).