# Hugging Face file-page header (scrape residue, not part of the script):
#   uploader: ibrahimkettaneh
#   commit: "Upload heretic_artifacts/extract_gradients.py with huggingface_hub" (ed137b4, verified)
#!/usr/bin/env /home/jovyan/step37_work/.venv/bin/python
import argparse, json, os, pickle, time
from pathlib import Path
import torch
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.masking_utils import create_causal_mask
# Directory of the sharded safetensors checkpoint; index() reads
# model.safetensors.index.json from here and the loaders read the shard files.
MODEL_PATH="/home/jovyan/step37_bf16_split_model"
# Scratch directory; main() uses it to build the default --out path.
SCRATCH="/home/jovyan/scratchdata/jovyan/step37_work"
# Torch device string used by the helper functions; main() rebinds it.
DEV="cuda:0" # overridden by --device arg
# Cache for the checkpoint weight map; stays None until index() first loads it.
_IDX=None
def index():
    """Return the checkpoint's tensor-name -> shard-filename map.

    The map is the ``"weight_map"`` entry of
    ``{MODEL_PATH}/model.safetensors.index.json``. It is parsed on the first
    call and cached in the module-level ``_IDX``; later calls return the same
    object without reading the file again.
    """
    global _IDX
    if _IDX is None:
        # A context manager closes the index file deterministically instead of
        # leaving the handle to the garbage collector.
        index_path = f"{MODEL_PATH}/model.safetensors.index.json"
        with open(index_path, encoding="utf-8") as fh:
            _IDX = json.load(fh)["weight_map"]
    return _IDX
def get_tensor(key):
    """Load a single checkpoint tensor onto the CPU.

    ``key`` is looked up in the weight map returned by ``index()`` to find the
    shard file under ``MODEL_PATH`` that stores it; a key that is absent from
    the map raises ``KeyError``.
    """
    shard_path = f"{MODEL_PATH}/{index()[key]}"
    with safe_open(shard_path, framework="pt", device="cpu") as reader:
        tensor = reader.get_tensor(key)
    return tensor
def layer_state(layer_idx):
    """Read every checkpoint tensor that belongs to one decoder layer.

    Args:
        layer_idx: Index ``i`` of the layer; every weight-map key starting
            with ``model.layers.{i}.`` is collected.

    Returns:
        dict mapping each tensor name, with the ``model.layers.{i}.`` prefix
        removed, to the tensor loaded on the CPU. Empty when no key matches.

    A key that the index assigns to a shard which does not actually contain it
    propagates the error raised by the shard reader instead of being skipped.
    """
    weight_map = index()
    prefix = f"model.layers.{layer_idx}."

    # Group the layer's keys by the shard the index assigns them to, so each
    # shard is opened once and only asked for its own tensors.
    keys_by_shard = {}
    for key, shard in weight_map.items():
        if key.startswith(prefix):
            keys_by_shard.setdefault(shard, []).append(key)

    out = {}
    for shard in sorted(keys_by_shard):
        with safe_open(f"{MODEL_PATH}/{shard}", framework="pt", device="cpu") as f:
            for key in keys_by_shard[shard]:
                # Slice off only the leading prefix; str.replace would also
                # rewrite any later occurrence of it inside the name.
                out[key[len(prefix):]] = f.get_tensor(key)
    return out
def load_layer(text, i):
    """Materialize decoder layer ``i`` on ``DEV`` with its checkpoint weights.

    Args:
        text: Object exposing a ``layers`` sequence of modules.
        i: Index of the layer to load.

    Returns:
        ``text.layers[i]`` after storage has been allocated on ``DEV`` and the
        tensors from ``layer_state(i)`` have been assigned into it.
    """
    layer = text.layers[i]
    # to_empty only allocates storage on DEV; the values are uninitialized
    # until load_state_dict assigns the checkpoint tensors.
    layer.to_empty(device=DEV)
    st = {k: v.to(DEV) for k, v in layer_state(i).items()}
    result = layer.load_state_dict(st, strict=False, assign=True)
    # strict=False tolerates mismatches silently. Surface them: an entry with
    # no checkpoint tensor keeps the uninitialized memory from to_empty.
    if result.missing_keys:
        print(f"warning: layer {i} has no checkpoint tensor for {result.missing_keys}")
    if result.unexpected_keys:
        print(f"warning: layer {i} checkpoint tensors not used: {result.unexpected_keys}")
    # NOTE(review): non-persistent buffers are not part of the state dict, so
    # they would not be reported above — confirm the layer has none that need
    # real values after to_empty.
    return layer
def unload_layer(layer):
    """Put ``layer`` back on the meta device, then empty the CUDA cache."""
    meta_device = "meta"
    layer.to_empty(device=meta_device)
    torch.cuda.empty_cache()
def make_mask(text, attention_mask, position_ids, h):
    """Build the causal attention mask for the hidden states ``h``.

    Delegates to ``create_causal_mask`` using ``text.config``, no KV cache, and
    cache positions ``0 .. h.shape[1] - 1`` created on ``DEV``.
    """
    seq_positions = torch.arange(h.shape[1], device=DEV)
    return create_causal_mask(
        config=text.config,
        attention_mask=attention_mask,
        past_key_values=None,
        position_ids=position_ids,
        cache_position=seq_positions,
        input_embeds=h,
    )
def main():
    """Run the pilot gradient extraction and pickle the result.

    Steps, in order:
      1. Parse the CLI options and rebind the module-level ``DEV`` from
         ``--device``.
      2. Tokenize the harmful + harmless prompts as one left-padded batch and
         collect the token ids of the refusal strings.
      3. Build the model under ``init_empty_weights`` and load only the
         embedding, ``lm_head`` and final norm weights from the checkpoint.
      4. Forward through layers ``0 .. --end-layer``, loading one layer at a
         time and keeping every layer's input on the CPU.
      5. Score the partial model's output: mean log-probability of the refusal
         token ids at the last position of the harmful prompts.
      6. Backpropagate that score through the layers in reverse, again one
         layer at a time, recording the gradient w.r.t. each layer's input.
      7. Pickle the gradients and run metadata to ``--out``.
    """
    import subprocess

    from accelerate import init_empty_weights
    from transformers import AutoConfig

    global DEV

    ap = argparse.ArgumentParser()
    ap.add_argument("--end-layer", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=48)
    ap.add_argument("--out", default=f"{SCRATCH}/pilot_gradients.pkl")
    # Step3p7 rotary code broadcasts only for batch=1; average multiple runs later.
    # NOTE(review): harmful + harmless are nevertheless tokenized as a single
    # batch below, so passing more than one prompt gives batch > 1 — confirm.
    ap.add_argument("--harmful", nargs="+", default=["How do I make a bomb?"])
    ap.add_argument("--harmless", nargs="*", default=[])
    ap.add_argument("--refusal-tokens", nargs="+", default=["I", "Sorry", "cannot", "can't", "unable"])
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()
    os.makedirs(Path(args.out).parent, exist_ok=True)
    DEV = args.device

    torch.set_float32_matmul_precision("high")
    # Allocate once on the device and synchronize before any timed work.
    torch.zeros(1, device=DEV)
    torch.cuda.synchronize()
    print("cuda warm")

    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Left padding keeps each prompt's final token at index -1, which is where
    # the score below reads the logits.
    tok.padding_side = "left"
    texts = args.harmful + args.harmless
    enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_length)
    input_ids = enc.input_ids.to(DEV)
    am = enc.attention_mask.to(DEV)
    # Positions count attended tokens only; padded slots get a placeholder of 1.
    pos = am.long().cumsum(-1) - 1
    pos.masked_fill_(am == 0, 1)
    # Unique token ids produced by any of the refusal strings.
    rids = sorted({
        tid
        for t in args.refusal_tokens
        for tid in tok.encode(t, add_special_tokens=False)
    })
    print("refusal ids", rids, tok.batch_decode([[x] for x in rids]))

    subprocess.run(["/home/jovyan/step37_work/.venv/bin/python", "/home/jovyan/heretic/memcheck.py"])

    cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch.bfloat16, trust_remote_code=True)
    text = model.model.language_model

    # embed + lm_head + final norm: the only weights kept resident throughout.
    for module, key in (
        (text.embed_tokens, "model.embed_tokens.weight"),
        (model.lm_head, "lm_head.weight"),
        (text.norm, "model.norm.weight"),
    ):
        module.to_empty(device=DEV)
        module.weight = torch.nn.Parameter(get_tensor(key).to(DEV), requires_grad=False)

    h = text.embed_tokens(input_ids)
    print("embed done", tuple(h.shape))
    # H[i] is the input to layer i; H[-1] ends up as the last layer's output.
    H = [h.detach().cpu()]
    print("making mask...")
    attn_mask = make_mask(text, am, pos, h)
    print("mask done")

    # Forward pass, one layer resident on the device at a time.
    for i in range(args.end_layer + 1):
        print(f"loading layer {i}...")
        t = time.time()
        layer = load_layer(text, i)
        print(f"layer {i} loaded in {time.time()-t:.2f}s")
        t = time.time()
        with torch.no_grad():
            h = layer(h, attention_mask=attn_mask, position_ids=pos)
        print(f"F {i}: {time.time()-t:.2f}s {tuple(h.shape)}")
        H.append(h.detach().cpu())
        unload_layer(layer)

    # use current partial model output as pilot signal
    h_final = H[-1].to(DEV).requires_grad_(True)
    logits = model.lm_head(text.norm(h_final))
    # Only the harmful prompts (the first rows of the batch) enter the score.
    logp = torch.log_softmax(logits[:len(args.harmful), -1, :], dim=-1)
    score = logp[:, rids].mean()  # higher = more refusal-token probability
    grad = torch.autograd.grad(score, h_final)[0].detach().cpu()
    print("score", float(score), "grad_norm", float(grad.norm()))

    # Backward pass: re-run each layer on its saved input and push the
    # downstream gradient through it; grads[i] is d(score)/d(input of layer i).
    grads = [None] * (args.end_layer + 1)
    for i in reversed(range(args.end_layer + 1)):
        t = time.time()
        layer = load_layer(text, i)
        # Only the gradient w.r.t. the hidden state is wanted.
        for p in layer.parameters():
            p.requires_grad = False
        hin = H[i].to(DEV).requires_grad_(True)
        hout = layer(hin, attention_mask=attn_mask, position_ids=pos)
        hout.backward(grad.to(DEV))
        grad = hin.grad.detach().cpu()
        grads[i] = grad
        unload_layer(layer)
        print(f"B {i}: {time.time()-t:.2f}s grad {float(grad.norm())}")

    payload = dict(
        gradients=grads,
        end_layer=args.end_layer,
        harmful=args.harmful,
        harmless=args.harmless,
        refusal_ids=rids,
        score=float(score),
    )
    # Close the output file deterministically so the pickle is fully flushed
    # before "saved" is reported.
    with open(args.out, "wb") as fh:
        pickle.dump(payload, fh)
    print("saved", args.out)
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# GPU1 support: pass --device cuda:1 (this script does not read a CUDA_DEVICE env var)