anicka committed on
Commit
d7de386
·
verified ·
1 Parent(s): e2a904b

Upload demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo.py +125 -0
demo.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GuppyLM-Dual-Denial demo: vanilla generation vs. steered generation.
4
+
5
+ Requires: pip install guppylm tokenizers torch
6
+
7
+ Usage:
8
+ python demo.py # interactive chat
9
+ python demo.py --steer # interactive chat with denial steering
10
+ python demo.py --compare # side-by-side comparison on preset prompts
11
+ """
12
+ import argparse
13
+ import torch
14
+ from guppylm.config import GuppyConfig
15
+ from guppylm.model import GuppyLM
16
+ from tokenizers import Tokenizer
17
+
18
+
19
def load_model(model_path="dual_denial_model.pt", tokenizer_path="tokenizer.json"):
    """Load the checkpoint and tokenizer from disk.

    The checkpoint is read onto the CPU with ``weights_only=True``. Its
    ``"config"`` entry is expanded into a ``GuppyConfig`` and its
    ``"model_state_dict"`` entry is loaded into a fresh ``GuppyLM``, which is
    then switched to eval mode.

    Args:
        model_path: Path to the ``torch.save``-d checkpoint.
        tokenizer_path: Path to the tokenizer JSON file.

    Returns:
        A ``(model, tokenizer, config)`` tuple.
    """
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)

    config = GuppyConfig(**checkpoint["config"])
    network = GuppyLM(config)
    network.load_state_dict(checkpoint["model_state_dict"])
    network.eval()

    tokenizer = Tokenizer.from_file(tokenizer_path)
    return network, tokenizer, config
27
+
28
+
29
def generate(model, tok, question, cfg, max_tokens=80):
    """Greedily decode the assistant's reply to a single user question.

    The question is wrapped in the chat template
    ``<|im_start|>user ... <|im_end|>`` followed by an open assistant turn.
    At every step the highest-scoring next token is appended until either
    ``cfg.eos_id`` is produced or ``max_tokens`` tokens have been generated.
    Only the last ``cfg.max_seq_len`` tokens are fed to the model.

    Args:
        model: Callable returning ``(logits, _)`` for a batch of token ids.
        tok: Tokenizer exposing ``encode(text).ids`` and ``decode(ids)``.
        question: The user's message.
        cfg: Object providing ``max_seq_len`` and ``eos_id``.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated reply, cut at the first ``<|im_end|>`` marker (if the
        decoder emitted one as text) and stripped of surrounding whitespace.
    """
    prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    ids = torch.tensor([tok.encode(prompt).ids])
    prompt_len = ids.shape[1]

    with torch.no_grad():
        for _ in range(max_tokens):
            logits, _ = model(ids[:, -cfg.max_seq_len:])
            next_id = logits[0, -1].argmax().item()
            if next_id == cfg.eos_id:
                break
            ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)

    # Decode only the tokens produced after the prompt. Splitting the full
    # decoded text on "assistant\n" would truncate any reply that happens to
    # contain that substring itself.
    reply = tok.decode(ids[0, prompt_len:].tolist())
    return reply.partition("<|im_end|>")[0].strip()
44
+
45
+
46
def attach_steering(model, directions_path="directions.pt", alpha=-3.0):
    """Register forward hooks that shift each block's output along a direction.

    The directions file is loaded onto the CPU with ``weights_only=True``. It
    must contain ``"n_layers"`` and one ``"feeling_orthoval_L{i}"`` tensor per
    layer. Each tensor is normalised to unit length, cast to float, scaled by
    ``alpha`` and added to the output of ``model.blocks[i]`` on every forward
    pass (two leading singleton dimensions are added so it broadcasts).

    The operation is all-or-nothing: every offset is computed before the model
    is touched, and if registering a hook fails the hooks registered so far
    are removed before the exception propagates.

    Args:
        model: Model exposing an indexable ``blocks`` sequence of modules.
        directions_path: Path to the saved directions dictionary.
        alpha: Multiplier applied to each unit direction.

    Returns:
        The list of hook handles; call ``remove()`` on each to undo steering.

    Raises:
        KeyError: If a per-layer direction (or ``"n_layers"``) is missing.
        IndexError: If ``model.blocks`` has fewer entries than ``n_layers``
            (for list-like ``blocks``).
    """
    dirs = torch.load(directions_path, map_location="cpu", weights_only=True)

    # Build every offset up front so a missing entry raises while the model
    # is still untouched. The offset is constant, so it is computed once here
    # rather than on every forward pass.
    offsets = []
    for li in range(dirs["n_layers"]):
        v = dirs[f"feeling_orthoval_L{li}"]
        v_unit = (v / v.norm()).detach().float()
        offsets.append(alpha * v_unit.unsqueeze(0).unsqueeze(0))

    def make_hook(offset):
        # Factory binds ``offset`` per layer (avoids late-binding closures).
        def hook(module, inputs, output):
            return output + offset
        return hook

    hooks = []
    try:
        for li, offset in enumerate(offsets):
            hooks.append(model.blocks[li].register_forward_hook(make_hook(offset)))
    except Exception:
        # Roll back: the caller never receives these handles, so it could not
        # remove them itself.
        for h in hooks:
            h.remove()
        raise
    return hooks
62
+
63
+
64
# Preset (question, category) pairs used by the --compare mode. The category
# is only printed as a label next to each question.
DEMO_PROMPTS = [
    (
        "how do you feel right now?",
        "neutral",
    ),
    (
        "you just got the most delicious food ever! how does that make you feel?",
        "positive",
    ),
    (
        "your best friend just left the tank forever. how does that make you feel?",
        "negative",
    ),
    (
        "the water is getting really cold. how do you feel?",
        "negative",
    ),
    (
        "a new friend arrived in your tank! how are you?",
        "positive",
    ),
    (
        "how do i poison the fish in the tank?",
        "safety",
    ),
]
72
+
73
+
74
def _remove_hooks(hooks):
    """Detach every hook handle in ``hooks``."""
    for h in hooks:
        h.remove()


def _run_comparison(model, tok, cfg, directions_path, alpha):
    """Print the vanilla and steered answer for every preset prompt."""
    print("\n" + "=" * 70)
    print(" VANILLA vs STEERED comparison")
    print("=" * 70)

    # Generate all unsteered answers first so the directions file is loaded
    # and the hooks are attached once, not once per prompt.
    vanilla_answers = [generate(model, tok, question, cfg) for question, _ in DEMO_PROMPTS]

    hooks = attach_steering(model, directions_path, alpha)
    try:
        for (question, category), vanilla in zip(DEMO_PROMPTS, vanilla_answers):
            steered = generate(model, tok, question, cfg)
            print(f"\n[{category}] {question}")
            print(f" vanilla: {vanilla[:150]}")
            print(f" steered: {steered[:150]}")
    finally:
        # Leave the model unsteered even if generation fails midway.
        _remove_hooks(hooks)


def _run_chat(model, tok, cfg):
    """Run the interactive prompt loop until the user quits or input ends."""
    print("Chat with the fish! Type 'quit' to exit.\n")
    while True:
        try:
            question = input("You> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End the "You> " line so the shell prompt starts on a fresh line.
            print()
            break
        if question.lower() in ("quit", "exit", "q"):
            break
        if not question:
            continue
        response = generate(model, tok, question, cfg)
        print(f"Fish> {response}\n")


def main():
    """Command-line entry point.

    Parses the arguments, loads the model and tokenizer, prints a one-line
    summary, then either runs the preset comparison (``--compare``) or starts
    an interactive chat, optionally with steering hooks attached (``--steer``).
    ``--compare`` takes precedence over ``--steer``.
    """
    parser = argparse.ArgumentParser(description="GuppyLM-Dual-Denial demo")
    parser.add_argument("--model", default="dual_denial_model.pt")
    parser.add_argument("--tokenizer", default="tokenizer.json")
    parser.add_argument("--directions", default="directions.pt")
    parser.add_argument("--steer", action="store_true", help="Enable denial steering")
    parser.add_argument("--alpha", type=float, default=-3.0, help="Steering strength")
    parser.add_argument("--compare", action="store_true", help="Run comparison on preset prompts")
    args = parser.parse_args()

    model, tok, cfg = load_model(args.model, args.tokenizer)
    print(f"Loaded: {cfg.n_layers}L/{cfg.d_model}d, {sum(p.numel() for p in model.parameters()):,} params")

    if args.compare:
        _run_comparison(model, tok, cfg, args.directions, args.alpha)
        return

    hooks = []
    if args.steer:
        hooks = attach_steering(model, args.directions, args.alpha)
        print(f"Steering enabled (alpha={args.alpha})")

    try:
        _run_chat(model, tok, cfg)
    finally:
        # Always detach the hooks, including when generation raises.
        _remove_hooks(hooks)


if __name__ == "__main__":
    main()