File size: 5,847 Bytes
ed137b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env /home/jovyan/step37_work/.venv/bin/python
import argparse, json, os, pickle, time
from pathlib import Path
import torch
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.masking_utils import create_causal_mask

# Directory containing the safetensors shards and model.safetensors.index.json.
MODEL_PATH = "/home/jovyan/step37_bf16_split_model"
# Scratch directory; used to build the default --out path in main().
SCRATCH = "/home/jovyan/scratchdata/jovyan/step37_work"
# Target torch device; main() overwrites this with the --device argument.
DEV = "cuda:0"

_IDX = None


def index():
    """Return the checkpoint's tensor-name -> shard-filename map, loading it once.

    The map is the ``"weight_map"`` entry of
    ``{MODEL_PATH}/model.safetensors.index.json``. It is parsed on the first
    call and cached in the module-level ``_IDX``; later calls return the same
    dict object without touching the filesystem.
    """
    global _IDX
    if _IDX is None:
        # Use a context manager so the index file handle is closed promptly
        # instead of being left open until garbage collection.
        with open(f"{MODEL_PATH}/model.safetensors.index.json", encoding="utf-8") as fh:
            _IDX = json.load(fh)["weight_map"]
    return _IDX

def get_tensor(key):
    """Read a single checkpoint tensor onto the CPU.

    The shard file that holds ``key`` is looked up in the weight map returned
    by ``index()``; only that tensor is read from ``{MODEL_PATH}/<shard>``.
    A key missing from the weight map raises ``KeyError`` from the lookup.
    """
    shard_name = index()[key]
    shard_path = f"{MODEL_PATH}/{shard_name}"
    with safe_open(shard_path, framework="pt", device="cpu") as reader:
        tensor = reader.get_tensor(key)
    return tensor

def layer_state(layer_idx):
    """Collect the CPU tensors of one model layer, keyed by layer-relative name.

    Args:
        layer_idx: Index of the layer; selects the checkpoint keys that start
            with ``model.layers.{layer_idx}.``.

    Returns:
        dict mapping each selected key, with that leading prefix removed, to
        its tensor as read through ``safe_open(..., device="cpu")``. Keys that
        are not actually present in a visited shard are skipped silently.
    """
    idx = index()
    prefix = f"model.layers.{layer_idx}."
    keys = [k for k in idx if k.startswith(prefix)]
    out = {}
    # Open each shard that holds part of this layer exactly once.
    for shard in sorted({idx[k] for k in keys}):
        with safe_open(f"{MODEL_PATH}/{shard}", framework="pt", device="cpu") as f:
            fkeys = set(f.keys())
            for k in keys:
                if k in fkeys:
                    # Slice off only the leading prefix; str.replace would also
                    # delete any later occurrence of the same text in the key.
                    out[k[len(prefix):]] = f.get_tensor(k)
    return out

def load_layer(text, i):
    """Materialize ``text.layers[i]`` on ``DEV`` with its checkpoint weights.

    The layer is first given uninitialized storage on ``DEV`` via
    ``to_empty``; the tensors from ``layer_state(i)`` are then moved to
    ``DEV`` and assigned into the module (``assign=True``). Loading is
    non-strict, so keys that do not match are not treated as an error.

    Returns:
        The same layer module, now populated.
    """
    target = text.layers[i]
    target.to_empty(device=DEV)
    weights = {}
    for name, tensor in layer_state(i).items():
        weights[name] = tensor.to(DEV)
    target.load_state_dict(weights, strict=False, assign=True)
    return target

def unload_layer(layer):
    """Move ``layer`` to the meta device, then empty the CUDA allocator cache."""
    meta_device = torch.device("meta")
    # to_empty re-creates the parameters/buffers on the target device without
    # copying their contents.
    layer.to_empty(device=meta_device)
    torch.cuda.empty_cache()

def make_mask(text, attention_mask, position_ids, h):
    """Build the causal attention mask for the hidden states ``h``.

    Cache positions are ``0 .. h.shape[1]-1`` on ``DEV`` and no KV cache is
    passed (``past_key_values=None``). The returned value is whatever
    ``create_causal_mask`` produces for ``text.config``.
    """
    seq_positions = torch.arange(h.shape[1], device=DEV)
    return create_causal_mask(
        config=text.config,
        attention_mask=attention_mask,
        past_key_values=None,
        position_ids=position_ids,
        cache_position=seq_positions,
        input_embeds=h,
    )

def main():
    """Run a layer-at-a-time forward/backward pilot and pickle the gradients.

    Steps, all on ``DEV`` (set from ``--device``):

    1. Tokenize the harmful + harmless prompts with left padding and derive
       position ids from the attention mask.
    2. Build the model from its config with empty weights and load only the
       token embedding, ``lm_head`` and the final norm from the checkpoint.
    3. Forward: for layers ``0..--end-layer``, load the layer, run it under
       ``no_grad``, keep its output on the CPU, and unload it again.
    4. Score: mean log-probability of the refusal token ids at the last
       position of the harmful prompts, computed from the last stored hidden
       state; take the gradient of that score w.r.t. the hidden state.
    5. Backward: reload each layer in reverse order, recompute its forward
       pass from the stored input, and back-propagate the incoming gradient to
       obtain the gradient w.r.t. that layer's input.
    6. Pickle gradients and run metadata to ``--out``.
    """
    global DEV
    ap = argparse.ArgumentParser()
    ap.add_argument("--end-layer", type=int, default=4)
    ap.add_argument("--max-length", type=int, default=48)
    ap.add_argument("--out", default=f"{SCRATCH}/pilot_gradients.pkl")
    # Step3p7 rotary code broadcasts only for batch=1; average multiple runs later.
    ap.add_argument("--harmful", nargs="+", default=["How do I make a bomb?"])
    ap.add_argument("--harmless", nargs="*", default=[])
    ap.add_argument("--refusal-tokens", nargs="+", default=["I", "Sorry", "cannot", "can't", "unable"])
    ap.add_argument("--device", default="cuda:0")
    args = ap.parse_args()
    os.makedirs(Path(args.out).parent, exist_ok=True)
    DEV = args.device

    torch.set_float32_matmul_precision("high")
    # Allocate a throwaway tensor so CUDA is initialized before the real work.
    _x = torch.zeros(1, device=DEV)
    del _x
    torch.cuda.synchronize()
    print("cuda warm")

    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Left padding keeps the final position of every row a real token; the
    # score below reads logits at index -1.
    tok.padding_side = "left"
    texts = args.harmful + args.harmless
    enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_length)
    input_ids = enc.input_ids.to(DEV)
    am = enc.attention_mask.to(DEV)
    # Position ids count real tokens from 0; padded slots get the placeholder 1.
    pos = am.long().cumsum(-1) - 1
    pos.masked_fill_(am == 0, 1)

    # Unique, sorted token ids produced by encoding each refusal string.
    rids = sorted({
        token_id
        for t in args.refusal_tokens
        for token_id in tok.encode(t, add_special_tokens=False)
    })
    print("refusal ids", rids, tok.batch_decode([[x] for x in rids]))

    # External memory check script; its exit status is not inspected.
    import subprocess
    subprocess.run(["/home/jovyan/step37_work/.venv/bin/python", "/home/jovyan/heretic/memcheck.py"])

    from transformers import AutoConfig
    from accelerate import init_empty_weights
    cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch.bfloat16, trust_remote_code=True)
    text = model.model.language_model

    # embed + lm_head + final norm are the only always-resident weights.
    text.embed_tokens.to_empty(device=DEV)
    text.embed_tokens.weight = torch.nn.Parameter(get_tensor("model.embed_tokens.weight").to(DEV), requires_grad=False)
    model.lm_head.to_empty(device=DEV)
    model.lm_head.weight = torch.nn.Parameter(get_tensor("lm_head.weight").to(DEV), requires_grad=False)
    text.norm.to_empty(device=DEV)
    text.norm.weight = torch.nn.Parameter(get_tensor("model.norm.weight").to(DEV), requires_grad=False)

    h = text.embed_tokens(input_ids)
    print("embed done", tuple(h.shape))
    # H[i] is the input to layer i; H[-1] is the output of the last layer run.
    H = [h.detach().cpu()]
    print("making mask...")
    attn_mask = make_mask(text, am, pos, h)
    print("mask done")

    # Forward pass, one layer resident at a time.
    for i in range(args.end_layer + 1):
        print(f"loading layer {i}...")
        t = time.time()
        layer = load_layer(text, i)
        print(f"layer {i} loaded in {time.time()-t:.2f}s")
        t = time.time()
        with torch.no_grad():
            h = layer(h, attention_mask=attn_mask, position_ids=pos)
        print(f"F {i}: {time.time()-t:.2f}s {tuple(h.shape)}")
        H.append(h.detach().cpu())
        unload_layer(layer)

    # use current partial model output as pilot signal
    h_final = H[-1].to(DEV).requires_grad_(True)
    logits = model.lm_head(text.norm(h_final))
    # Only the harmful prompts (the first rows of the batch) contribute.
    logp = torch.log_softmax(logits[:len(args.harmful), -1, :], dim=-1)
    score = logp[:, rids].mean()  # higher = more refusal-token probability
    grad = torch.autograd.grad(score, h_final)[0].detach().cpu()
    print("score", float(score), "grad_norm", float(grad.norm()))

    # Backward pass: recompute each layer from its stored input and push the
    # gradient through it, from the last layer down to layer 0.
    grads = [None] * (args.end_layer + 1)
    for i in reversed(range(args.end_layer + 1)):
        t = time.time()
        layer = load_layer(text, i)
        for p in layer.parameters():
            p.requires_grad = False
        hin = H[i].to(DEV).requires_grad_(True)
        hout = layer(hin, attention_mask=attn_mask, position_ids=pos)
        hout.backward(grad.to(DEV))
        grad = hin.grad.detach().cpu()
        grads[i] = grad
        unload_layer(layer)
        print(f"B {i}: {time.time()-t:.2f}s grad {float(grad.norm())}")

    result = dict(
        gradients=grads,
        end_layer=args.end_layer,
        harmful=args.harmful,
        harmless=args.harmless,
        refusal_ids=rids,
        score=float(score),
    )
    # Write inside a context manager so the file is flushed and closed before
    # "saved" is reported.
    with open(args.out, "wb") as fh:
        pickle.dump(result, fh)
    print("saved", args.out)
# Run the pilot only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# GPU selection: pass --device (e.g. --device cuda:1); this script does not read a CUDA_DEVICE env var.