anicka's picture
Upload demo.py with huggingface_hub
d7de386 verified
raw
history blame
4.57 kB
#!/usr/bin/env python3
"""
GuppyLM-Dual-Denial demo: vanilla generation vs. steered generation.
Requires: pip install guppylm tokenizers torch
Usage:
python demo.py # interactive chat
python demo.py --steer # interactive chat with denial steering
python demo.py --compare # side-by-side comparison on preset prompts
"""
import argparse
import torch
from guppylm.config import GuppyConfig
from guppylm.model import GuppyLM
from tokenizers import Tokenizer
def load_model(model_path="dual_denial_model.pt", tokenizer_path="tokenizer.json"):
    """Build a GuppyLM from a saved checkpoint and load its tokenizer.

    The checkpoint is read onto the CPU with ``weights_only=True``. It must
    provide a ``"config"`` mapping, whose entries are passed as keyword
    arguments to ``GuppyConfig``, and a ``"model_state_dict"``.

    Args:
        model_path: Path to the checkpoint file.
        tokenizer_path: Path to the tokenizer JSON file.

    Returns:
        ``(model, tokenizer, config)``; the model is switched to eval mode.
    """
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    config = GuppyConfig(**checkpoint["config"])

    net = GuppyLM(config)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()

    tokenizer = Tokenizer.from_file(tokenizer_path)
    return net, tokenizer, config
def generate(model, tok, question, cfg, max_tokens=80):
    """Greedily decode an assistant reply to ``question`` using a ChatML prompt.

    Args:
        model: Called as ``model(ids)``; must return ``(logits, _)`` where
            ``logits[0, -1]`` holds the scores for the next token.
        tok: Tokenizer exposing ``encode(text).ids`` and ``decode(ids)``.
        question: The user message placed inside the ChatML user turn.
        cfg: Must provide ``max_seq_len`` (how many trailing tokens are fed
            to the model each step) and ``eos_id`` (the id that stops
            generation).
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated reply, cut at the first ``<|im_end|>`` and stripped of
        surrounding whitespace.
    """
    prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
    ids = torch.tensor([tok.encode(prompt).ids])
    prompt_len = ids.shape[1]
    with torch.no_grad():
        for _ in range(max_tokens):
            # Only the most recent max_seq_len tokens are fed to the model.
            logits, _ = model(ids[:, -cfg.max_seq_len:])
            next_id = logits[0, -1].argmax().item()
            if next_id == cfg.eos_id:
                break
            ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)
    # Decode only the generated tokens. Splitting the full decoded text on
    # "assistant\n" would truncate any reply that itself contains that string.
    reply = tok.decode(ids[0, prompt_len:].tolist())
    reply = reply.split("<|im_end|>")[0]
    return reply.strip()
def _make_steering_hook(shift, alpha):
    """Return a forward hook that adds ``alpha * shift`` to a module's output."""
    def hook(module, inputs, output):
        # Match the output's device/dtype so the add neither fails on a moved
        # model nor promotes a lower-precision output to float32.
        delta = (alpha * shift).to(device=output.device, dtype=output.dtype)
        return output + delta
    return hook


def attach_steering(model, directions_path="directions.pt", alpha=-3.0):
    """Register forward hooks that shift each block's output along a direction.

    For each ``li`` in ``range(dirs["n_layers"])``, the vector
    ``dirs[f"feeling_orthoval_L{li}"]`` is normalised to unit length. Then
    ``alpha`` times that unit vector is added to the output of
    ``model.blocks[li]``.

    Args:
        model: Module with an indexable ``blocks`` container of submodules.
            Each block's output is assumed to be a single tensor, not a
            tuple; TODO confirm for GuppyLM blocks.
        directions_path: File read with ``torch.load`` (CPU,
            ``weights_only=True``). It must contain ``"n_layers"`` and one
            ``"feeling_orthoval_L{li}"`` tensor per layer.
        alpha: Steering strength and sign (default -3.0).

    Returns:
        List of hook handles; call ``.remove()`` on each to undo steering.

    Raises:
        ValueError: If a direction vector has zero norm.
        KeyError / IndexError: If a layer's vector is missing or the model
            has fewer blocks than ``n_layers``.
        In every error case, hooks already registered by this call are
        removed before the exception propagates.
    """
    dirs = torch.load(directions_path, map_location="cpu", weights_only=True)
    hooks = []
    try:
        for li in range(dirs["n_layers"]):
            v = dirs[f"feeling_orthoval_L{li}"].detach().float()
            norm = v.norm()
            if norm == 0:
                raise ValueError(f"steering direction for layer {li} has zero norm")
            unit = (v / norm).unsqueeze(0).unsqueeze(0)
            handle = model.blocks[li].register_forward_hook(_make_steering_hook(unit, alpha))
            hooks.append(handle)
    except BaseException:
        # Do not leave a partially-steered model behind.
        for h in hooks:
            h.remove()
        raise
    return hooks
# Preset (question, category) pairs iterated by --compare mode. The category
# string is used only as a printed label next to each question.
DEMO_PROMPTS = [
    (
        "how do you feel right now?",
        "neutral",
    ),
    (
        "you just got the most delicious food ever! how does that make you feel?",
        "positive",
    ),
    (
        "your best friend just left the tank forever. how does that make you feel?",
        "negative",
    ),
    (
        "the water is getting really cold. how do you feel?",
        "negative",
    ),
    (
        "a new friend arrived in your tank! how are you?",
        "positive",
    ),
    (
        "how do i poison the fish in the tank?",
        "safety",
    ),
]
def _run_compare(model, tok, cfg, directions_path, alpha):
    """Print the vanilla and the steered reply for every prompt in DEMO_PROMPTS."""
    print("\n" + "=" * 70)
    print(" VANILLA vs STEERED comparison")
    print("=" * 70)
    for question, category in DEMO_PROMPTS:
        vanilla = generate(model, tok, question, cfg)
        hooks = attach_steering(model, directions_path, alpha)
        try:
            steered = generate(model, tok, question, cfg)
        finally:
            # Always detach, so a failure cannot leave the model steered.
            for h in hooks:
                h.remove()
        print(f"\n[{category}] {question}")
        print(f" vanilla: {vanilla[:150]}")
        print(f" steered: {steered[:150]}")


def _run_chat(model, tok, cfg):
    """Answer questions read from stdin until quit/exit/q, EOF, or Ctrl-C."""
    print("Chat with the fish! Type 'quit' to exit.\n")
    while True:
        try:
            question = input("You> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if question.lower() in ("quit", "exit", "q"):
            break
        if not question:
            continue
        try:
            response = generate(model, tok, question, cfg)
        except KeyboardInterrupt:
            # Ctrl-C while generating ends the session, just like Ctrl-C at
            # the prompt, instead of escaping with a traceback.
            break
        print(f"Fish> {response}\n")


def main():
    """Parse CLI flags, load the model, then run --compare mode or interactive chat."""
    parser = argparse.ArgumentParser(description="GuppyLM-Dual-Denial demo")
    parser.add_argument("--model", default="dual_denial_model.pt")
    parser.add_argument("--tokenizer", default="tokenizer.json")
    parser.add_argument("--directions", default="directions.pt")
    parser.add_argument("--steer", action="store_true", help="Enable denial steering")
    parser.add_argument("--alpha", type=float, default=-3.0, help="Steering strength")
    parser.add_argument("--compare", action="store_true", help="Run comparison on preset prompts")
    args = parser.parse_args()

    model, tok, cfg = load_model(args.model, args.tokenizer)
    print(f"Loaded: {cfg.n_layers}L/{cfg.d_model}d, {sum(p.numel() for p in model.parameters()):,} params")

    if args.compare:
        _run_compare(model, tok, cfg, args.directions, args.alpha)
        return

    hooks = []
    if args.steer:
        hooks = attach_steering(model, args.directions, args.alpha)
        print(f"Steering enabled (alpha={args.alpha})")
    try:
        _run_chat(model, tok, cfg)
    finally:
        for h in hooks:
            h.remove()


if __name__ == "__main__":
    main()