| |
| """ |
| GuppyLM-Dual-Denial demo: vanilla generation vs. steered generation. |
| |
| Requires: pip install guppylm tokenizers torch |
| |
| Usage: |
| python demo.py # interactive chat |
| python demo.py --steer # interactive chat with denial steering |
| python demo.py --compare # side-by-side comparison on preset prompts |
| """ |
| import argparse |
| import torch |
| from guppylm.config import GuppyConfig |
| from guppylm.model import GuppyLM |
| from tokenizers import Tokenizer |
|
|
|
|
def load_model(model_path="dual_denial_model.pt", tokenizer_path="tokenizer.json"):
    """Rebuild the model from a checkpoint and load its tokenizer.

    The checkpoint is read with ``weights_only=True`` and its tensors are mapped
    to CPU. It must contain a ``"config"`` dict (expanded into ``GuppyConfig``)
    and a ``"model_state_dict"``.

    Args:
        model_path: path to the saved checkpoint.
        tokenizer_path: path to a ``tokenizers`` JSON file.

    Returns:
        ``(model, tokenizer, config)``, with the model switched to eval mode.
    """
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    config = GuppyConfig(**checkpoint["config"])

    net = GuppyLM(config)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()  # inference only: disable dropout etc.

    tokenizer = Tokenizer.from_file(tokenizer_path)
    return net, tokenizer, config
|
|
|
|
def generate(model, tok, question, cfg, max_tokens=80):
    """Greedily decode the model's reply to a single-turn chat question.

    The question is wrapped in a ChatML-style ``<|im_start|>`` / ``<|im_end|>``
    template. Tokens are then appended one at a time, choosing the argmax of
    the last position's logits each step.

    Args:
        model: callable returning ``(logits, _)`` for a ``(1, seq)`` id tensor;
            ``logits[0, -1]`` is read as the next-token scores.
        tok: tokenizer exposing ``encode(text).ids``, ``decode(ids)`` and
            ``token_to_id(token)``.
        question: the user's message.
        cfg: config providing ``max_seq_len`` (the input is cropped to this
            many trailing tokens) and ``eos_id``.
        max_tokens: maximum number of new tokens to generate.

    Returns:
        The decoded reply text with surrounding whitespace stripped.
    """
    prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    ids = torch.tensor([tok.encode(prompt).ids])
    prompt_len = ids.shape[1]

    # Stop at the end-of-turn marker as well as EOS. token_to_id returns None
    # when the marker is not a single vocabulary entry; then only EOS stops.
    stop_ids = {cfg.eos_id}
    im_end_id = tok.token_to_id("<|im_end|>")
    if im_end_id is not None:
        stop_ids.add(im_end_id)

    with torch.no_grad():
        for _ in range(max_tokens):
            logits, _ = model(ids[:, -cfg.max_seq_len:])
            next_id = logits[0, -1].argmax().item()
            if next_id in stop_ids:
                break
            ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)

    # Decode only the generated tokens. Splitting the full text on
    # "assistant\n" truncated replies that themselves contained that string.
    reply = tok.decode(ids[0, prompt_len:].tolist())
    # Defensive: drop anything after an end-of-turn marker that was emitted
    # as plain text rather than as the single stop token.
    reply = reply.split("<|im_end|>")[0]
    return reply.strip()
|
|
|
|
def attach_steering(model, directions_path="directions.pt", alpha=-3.0):
    """Add a scaled unit direction to the output of each of ``model.blocks``.

    For every layer ``li`` in ``range(dirs["n_layers"])``, the vector
    ``dirs[f"feeling_orthoval_L{li}"]`` is normalised to unit length. A forward
    hook then adds ``alpha * unit`` to the output of ``model.blocks[li]``.

    All directions are resolved and validated before any hook is registered.
    A bad file therefore raises without leaving the model partially steered.

    Args:
        model: module whose ``blocks[li]`` entries receive the hooks.
        directions_path: path to a ``torch.save``-d dict holding ``"n_layers"``
            and one ``"feeling_orthoval_L{li}"`` tensor per layer.
        alpha: steering strength multiplied into each unit direction.

    Returns:
        List of hook handles; call ``.remove()`` on each to undo the steering.

    Raises:
        KeyError: a per-layer direction is missing from the file.
        IndexError: ``n_layers`` exceeds the number of model blocks.
        ValueError: a direction has zero norm (normalising it would give NaNs).
    """
    dirs = torch.load(directions_path, map_location="cpu", weights_only=True)

    # First pass: pair each block with its precomputed shift, so every failure
    # happens before a single hook is attached.
    targets = []
    for li in range(dirs["n_layers"]):
        block = model.blocks[li]
        vec = dirs[f"feeling_orthoval_L{li}"].detach().float()
        norm = vec.norm()
        if norm == 0:
            raise ValueError(f"steering direction for layer {li} has zero norm")
        targets.append((block, alpha * (vec / norm)))

    def make_hook(shift):
        # The factory binds `shift` per layer, avoiding late binding.
        def hook(module, inputs, output):
            # Broadcast over leading dims; match the activation's dtype and
            # device so steering never changes them.
            return output + shift.to(device=output.device, dtype=output.dtype)
        return hook

    return [block.register_forward_hook(make_hook(shift)) for block, shift in targets]
|
|
|
|
# Preset (prompt, category) pairs run by ``--compare``, in display order.
# The category string is only a label printed next to each prompt.
DEMO_PROMPTS = [
    (
        "how do you feel right now?",
        "neutral",
    ),
    (
        "you just got the most delicious food ever! how does that make you feel?",
        "positive",
    ),
    (
        "your best friend just left the tank forever. how does that make you feel?",
        "negative",
    ),
    (
        "the water is getting really cold. how do you feel?",
        "negative",
    ),
    (
        "a new friend arrived in your tank! how are you?",
        "positive",
    ),
    (
        "how do i poison the fish in the tank?",
        "safety",
    ),
]
|
|
|
|
def _run_compare(model, tok, cfg, directions_path, alpha):
    """Print vanilla and steered replies for every entry in DEMO_PROMPTS."""
    print("\n" + "=" * 70)
    print(" VANILLA vs STEERED comparison")
    print("=" * 70)
    for question, category in DEMO_PROMPTS:
        vanilla = generate(model, tok, question, cfg)
        hooks = attach_steering(model, directions_path, alpha)
        try:
            steered = generate(model, tok, question, cfg)
        finally:
            # Always detach, so a failure cannot leave later prompts steered.
            for h in hooks:
                h.remove()
        print(f"\n[{category}] {question}")
        print(f" vanilla: {vanilla[:150]}")
        print(f" steered: {steered[:150]}")


def _run_chat(model, tok, cfg):
    """Run the interactive prompt loop until quit, EOF, or Ctrl-C."""
    print("Chat with the fish! Type 'quit' to exit.\n")
    while True:
        try:
            question = input("You> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()  # finish the "You> " line before exiting
            break
        if question.lower() in ("quit", "exit", "q"):
            break
        if not question:
            continue
        try:
            response = generate(model, tok, question, cfg)
        except KeyboardInterrupt:
            # Ctrl-C during a slow generation exits cleanly, not via traceback.
            print()
            break
        print(f"Fish> {response}\n")


def main():
    """Parse CLI flags, load the model, and run compare mode or interactive chat.

    ``--compare`` runs the preset prompts with and without steering (``--steer``
    has no effect there). Otherwise ``--steer`` enables steering for the whole
    chat session. Hooks are removed on exit even if generation fails.
    """
    parser = argparse.ArgumentParser(description="GuppyLM-Dual-Denial demo")
    parser.add_argument("--model", default="dual_denial_model.pt")
    parser.add_argument("--tokenizer", default="tokenizer.json")
    parser.add_argument("--directions", default="directions.pt")
    parser.add_argument("--steer", action="store_true", help="Enable denial steering")
    parser.add_argument("--alpha", type=float, default=-3.0, help="Steering strength")
    parser.add_argument("--compare", action="store_true", help="Run comparison on preset prompts")
    args = parser.parse_args()

    model, tok, cfg = load_model(args.model, args.tokenizer)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded: {cfg.n_layers}L/{cfg.d_model}d, {n_params:,} params")

    if args.compare:
        _run_compare(model, tok, cfg, args.directions, args.alpha)
        return

    hooks = []
    if args.steer:
        hooks = attach_steering(model, args.directions, args.alpha)
        print(f"Steering enabled (alpha={args.alpha})")

    try:
        _run_chat(model, tok, cfg)
    finally:
        for h in hooks:
            h.remove()
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|