Spaces:
Sleeping
Sleeping
Commit ·
3bd251d
1
Parent(s): a82ed9a
update
Browse files- examples/playground/chat.py +1 -1
- examples/playground/generation.py +78 -0
- examples/tutorials/dpo/ultrachat/step_1_prepare_data.py +59 -0
- examples/tutorials/dpo/ultrachat/step_2_train_sft_model.py +132 -0
- examples/tutorials/mix_lora_unsloth/step_2_train_model.py +2 -0
- examples/tutorials/rl/cart_pole/requirements.txt +3 -0
- examples/tutorials/rl/cart_pole/step_2_actor_critic.py +716 -0
- examples/tutorials/rl/cart_pole/step_2_ppo_clip.py +739 -0
- examples/tutorials/rl/cart_pole/step_2_ppo_penalty.py +767 -0
- examples/tutorials/rl/cart_pole/step_2_reinforce.py +382 -0
- examples/tutorials/rl/cart_pole/step_2_reinforce_with_baseline.py +332 -0
- examples/tutorials/rl/cart_pole/step_2_rl_dqn.py +251 -0
- examples/tutorials/rlhf/gpt2_sst2/step_1_prepare_data.py +59 -0
- examples/tutorials/rlhf/gpt2_sst2/step_2_train_sft_model.py +166 -0
- examples/tutorials/rlhf/gpt2_sst2/step_3_train_reward_model.py +295 -0
- examples/tutorials/rlhf/gpt2_sst2/step_4_test_reward_model.py +160 -0
- examples/tutorials/rlhf/gpt2_sst2/step_5_ppo_rlhf.py +430 -0
- examples/tutorials/rlhf/gpt2_sst2/step_5_ppo_rlhf2.py +430 -0
- examples/tutorials/rlhf/gpt2_sst2/step_5_pre_ppo_rlhf.py +257 -0
- tabs/chat_template_tab.py +2 -0
examples/playground/chat.py
CHANGED
|
@@ -17,7 +17,7 @@ def get_args():
|
|
| 17 |
parser.add_argument(
|
| 18 |
"--pretrained_model_name_or_path",
|
| 19 |
# default="jingyaogong/MiniMind2",
|
| 20 |
-
default=(project_path / "pretrained_models/MiniMind2"),
|
| 21 |
type=str
|
| 22 |
)
|
| 23 |
|
|
|
|
| 17 |
parser.add_argument(
|
| 18 |
"--pretrained_model_name_or_path",
|
| 19 |
# default="jingyaogong/MiniMind2",
|
| 20 |
+
default=(project_path / "pretrained_models/jingyaogong/MiniMind2"),
|
| 21 |
type=str
|
| 22 |
)
|
| 23 |
|
examples/playground/generation.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
https://github.com/jingyaogong/minimind/blob/master/eval_llm.py
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 11 |
+
|
| 12 |
+
from project_settings import project_path
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_args():
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument(
|
| 18 |
+
"--pretrained_model_name_or_path",
|
| 19 |
+
# default=(project_path / "trained_models/gpt2-sst2-generation"),
|
| 20 |
+
default=(project_path / "trained_models/gpt2-sst2-generation-20260213-2048"),
|
| 21 |
+
type=str
|
| 22 |
+
)
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"--max_new_tokens",
|
| 25 |
+
default=1024, # 8192, 128
|
| 26 |
+
type=int, help="最大生成长度(注意:并非模型实际长文本能力)"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
|
| 29 |
+
parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
|
| 30 |
+
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
return args
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main():
|
| 36 |
+
args = get_args()
|
| 37 |
+
|
| 38 |
+
if torch.cuda.is_available():
|
| 39 |
+
device = "cuda"
|
| 40 |
+
elif torch.backends.mps.is_available():
|
| 41 |
+
# device = "mps"
|
| 42 |
+
device = "cpu"
|
| 43 |
+
else:
|
| 44 |
+
device = "cpu"
|
| 45 |
+
print(f"device: {device}")
|
| 46 |
+
|
| 47 |
+
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
|
| 48 |
+
model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
|
| 49 |
+
model = model.eval().to(device)
|
| 50 |
+
|
| 51 |
+
tokenized = tokenizer(
|
| 52 |
+
# "this",
|
| 53 |
+
# "this is ",
|
| 54 |
+
# "who needs mind-bending",
|
| 55 |
+
"eldom has a movie",
|
| 56 |
+
# "thanks to scott 's charismatic",
|
| 57 |
+
return_tensors="pt"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 61 |
+
|
| 62 |
+
generated_ids = model.generate(
|
| 63 |
+
inputs=tokenized["input_ids"], attention_mask=tokenized["attention_mask"],
|
| 64 |
+
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
| 65 |
+
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
|
| 66 |
+
top_p=args.top_p, temperature=args.temperature, repetition_penalty=3.0,
|
| 67 |
+
early_stopping=True,
|
| 68 |
+
)
|
| 69 |
+
# response = tokenizer.decode(generated_ids[0][len(tokenized["input_ids"][0]):], skip_special_tokens=True)
|
| 70 |
+
response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
|
| 71 |
+
print(response)
|
| 72 |
+
print(generated_ids)
|
| 73 |
+
|
| 74 |
+
return
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
main()
|
examples/tutorials/dpo/ultrachat/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
或使用命令行
|
| 5 |
+
pip install modelscope
|
| 6 |
+
modelscope download \
|
| 7 |
+
--model 'qgyd2021/Qwen3-8B-sft-deepspeed' \
|
| 8 |
+
--local_dir '/root/autodl-tmp/trained_models/Qwen3-8B-sft-deepspeed'
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import platform
|
| 15 |
+
|
| 16 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 17 |
+
|
| 18 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 19 |
+
from project_settings import project_path, temp_directory
|
| 20 |
+
else:
|
| 21 |
+
project_path = os.path.abspath("../../../")
|
| 22 |
+
project_path = Path(project_path)
|
| 23 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 24 |
+
|
| 25 |
+
from modelscope import snapshot_download
|
| 26 |
+
# from huggingface_hub import snapshot_download
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_args():
|
| 30 |
+
parser = argparse.ArgumentParser()
|
| 31 |
+
parser.add_argument("--repo_id", default="openai-community/gpt2", type=str)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--local_dir",
|
| 34 |
+
default=(temp_directory / "../trained_models/openai-community/gpt2").as_posix(),
|
| 35 |
+
type=str
|
| 36 |
+
)
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
return args
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main():
|
| 42 |
+
args = get_args()
|
| 43 |
+
|
| 44 |
+
#modelscope
|
| 45 |
+
snapshot_download(
|
| 46 |
+
model_id=args.repo_id,
|
| 47 |
+
local_dir=args.local_dir,
|
| 48 |
+
)
|
| 49 |
+
#huggingface_hub
|
| 50 |
+
snapshot_download(
|
| 51 |
+
repo_type="model",
|
| 52 |
+
repo_id=args.repo_id,
|
| 53 |
+
local_dir=args.local_dir,
|
| 54 |
+
)
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
main()
|
examples/tutorials/dpo/ultrachat/step_2_train_sft_model.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import platform
|
| 7 |
+
|
| 8 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 9 |
+
|
| 10 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 11 |
+
from project_settings import project_path, temp_directory
|
| 12 |
+
else:
|
| 13 |
+
project_path = os.path.abspath("../../../")
|
| 14 |
+
project_path = Path(project_path)
|
| 15 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 16 |
+
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
| 20 |
+
from trl import SFTTrainer, SFTConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_args():
|
| 24 |
+
parser = argparse.ArgumentParser()
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--model_name",
|
| 27 |
+
default=(project_path / "pretrained_models/Qwen/Qwen2.5-0.5B").as_posix(),
|
| 28 |
+
type=str
|
| 29 |
+
),
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--dataset_path",
|
| 32 |
+
default="HuggingFaceH4/ultrachat_200k",
|
| 33 |
+
type=str
|
| 34 |
+
),
|
| 35 |
+
parser.add_argument("--dataset_name", default=None, type=str),
|
| 36 |
+
parser.add_argument("--dataset_split", default=None, type=str),
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--dataset_cache_dir",
|
| 39 |
+
default=(temp_directory / "hub_datasets").as_posix(),
|
| 40 |
+
type=str
|
| 41 |
+
),
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--model_cache_dir",
|
| 44 |
+
default=(temp_directory / "hub_models").as_posix(),
|
| 45 |
+
type=str
|
| 46 |
+
),
|
| 47 |
+
parser.add_argument("--dataset_streaming", default="false", type=str),
|
| 48 |
+
parser.add_argument("--valid_dataset_size", default=1000, type=int),
|
| 49 |
+
parser.add_argument("--shuffle_buffer_size", default=5000, type=int),
|
| 50 |
+
|
| 51 |
+
parser.add_argument("--max_seq_length", default=2048, type=int)
|
| 52 |
+
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--output_model_dir",
|
| 55 |
+
default=(project_path / "trained_models/qwen2_5-0_5B-ultrachat-sft").as_posix(),
|
| 56 |
+
type=str
|
| 57 |
+
),
|
| 58 |
+
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--num_workers",
|
| 61 |
+
default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
|
| 62 |
+
type=int
|
| 63 |
+
),
|
| 64 |
+
args = parser.parse_args()
|
| 65 |
+
return args
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main():
|
| 69 |
+
args = get_args()
|
| 70 |
+
|
| 71 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 72 |
+
|
| 73 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name)
|
| 74 |
+
model = model.to(args.device)
|
| 75 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 76 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 77 |
+
|
| 78 |
+
train_dataset = load_dataset(
|
| 79 |
+
path=args.dataset_path,
|
| 80 |
+
name=args.dataset_name,
|
| 81 |
+
split="train_sft",
|
| 82 |
+
cache_dir=args.dataset_cache_dir,
|
| 83 |
+
# num_proc=args.num_workers if not args.dataset_streaming else None,
|
| 84 |
+
streaming=True if args.dataset_streaming in ("true",) else False,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
sft_config = SFTConfig(
|
| 88 |
+
output_dir=args.output_model_dir,
|
| 89 |
+
num_train_epochs=1,
|
| 90 |
+
per_device_train_batch_size=4,
|
| 91 |
+
gradient_accumulation_steps=4,
|
| 92 |
+
save_strategy="steps",
|
| 93 |
+
save_steps=500,
|
| 94 |
+
save_total_limit=2,
|
| 95 |
+
logging_steps=100,
|
| 96 |
+
learning_rate=2e-5,
|
| 97 |
+
warmup_ratio=0.03,
|
| 98 |
+
lr_scheduler_type="cosine",
|
| 99 |
+
bf16=torch.cuda.is_available(),
|
| 100 |
+
tf32=torch.cuda.is_available(),
|
| 101 |
+
gradient_checkpointing=True,
|
| 102 |
+
optim="adamw_torch",
|
| 103 |
+
remove_unused_columns=False,
|
| 104 |
+
report_to="none",
|
| 105 |
+
dataloader_num_workers=args.num_workers or 0,
|
| 106 |
+
ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None,
|
| 107 |
+
|
| 108 |
+
# SFT specific parameters
|
| 109 |
+
max_length=args.max_seq_length,
|
| 110 |
+
dataset_text_field=None,
|
| 111 |
+
dataset_kwargs={
|
| 112 |
+
"add_special_tokens": True,
|
| 113 |
+
# "split": "train",
|
| 114 |
+
},
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# 创建 trainer
|
| 118 |
+
trainer = SFTTrainer(
|
| 119 |
+
model=model,
|
| 120 |
+
args=sft_config,
|
| 121 |
+
train_dataset=train_dataset,
|
| 122 |
+
processing_class=tokenizer,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
trainer.train()
|
| 126 |
+
trainer.save_model()
|
| 127 |
+
return
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
main()
|
| 132 |
+
|
examples/tutorials/mix_lora_unsloth/step_2_train_model.py
CHANGED
|
@@ -146,6 +146,8 @@ def main():
|
|
| 146 |
# max_steps = 30,
|
| 147 |
learning_rate=2e-5, # Reduce to 2e-5 for long training runs
|
| 148 |
logging_steps=1,
|
|
|
|
|
|
|
| 149 |
optim="adamw_8bit",
|
| 150 |
weight_decay=0.01,
|
| 151 |
lr_scheduler_type="linear",
|
|
|
|
| 146 |
# max_steps = 30,
|
| 147 |
learning_rate=2e-5, # Reduce to 2e-5 for long training runs
|
| 148 |
logging_steps=1,
|
| 149 |
+
save_steps=100, # 每500步保存一次检查点
|
| 150 |
+
save_total_limit=2, # 最多只保留2个检查点,旧的自动清理
|
| 151 |
optim="adamw_8bit",
|
| 152 |
weight_decay=0.01,
|
| 153 |
lr_scheduler_type="linear",
|
examples/tutorials/rl/cart_pole/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gymnasium
|
| 2 |
+
matplotlib
|
| 3 |
+
pygame
|
examples/tutorials/rl/cart_pole/step_2_actor_critic.py
ADDED
|
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
相比于 REINFORCE 方法,演员-评论家(Actor-Critic)方法不需要等到一局游戏结束就可以触发优化迭代。
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
A2C 计算价值优势,主要是对单次的优势进行了移动平均。
|
| 9 |
+
|
| 10 |
+
由于函数的优化步长受到优势的直接影响。当优步长过大时很容易直接跨过,导致优化失败。
|
| 11 |
+
虽然A2C 已经对历史优势进行了移动平均,但问题仍然存在。
|
| 12 |
+
尤其是当训练的早期价值函数还没有获得较好的训练,这种问题尤其容易出现。
|
| 13 |
+
因此需要对优化步长进行截断限制。
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
import gymnasium as gym
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.optim as optim
|
| 22 |
+
from torch.distributions import Categorical
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ============== 1. 基础Actor-Critic ==============
|
| 27 |
+
class ActorCritic(nn.Module):
|
| 28 |
+
"""共享网络的Actor-Critic"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, state_dim, action_dim, hidden_dim=256):
|
| 31 |
+
super(ActorCritic, self).__init__()
|
| 32 |
+
|
| 33 |
+
# 共享特征提取层
|
| 34 |
+
self.feature_layer = nn.Sequential(
|
| 35 |
+
nn.Linear(state_dim, hidden_dim),
|
| 36 |
+
nn.ReLU(),
|
| 37 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 38 |
+
nn.ReLU()
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Actor: 策略网络
|
| 42 |
+
self.actor = nn.Linear(hidden_dim, action_dim)
|
| 43 |
+
|
| 44 |
+
# Critic: 价值网络
|
| 45 |
+
self.critic = nn.Linear(hidden_dim, 1)
|
| 46 |
+
|
| 47 |
+
def forward(self, state):
|
| 48 |
+
features = self.feature_layer(state)
|
| 49 |
+
|
| 50 |
+
# Actor输出动作概率
|
| 51 |
+
action_probs = F.softmax(self.actor(features), dim=-1)
|
| 52 |
+
|
| 53 |
+
# Critic输出状态价值
|
| 54 |
+
state_value = self.critic(features)
|
| 55 |
+
|
| 56 |
+
return action_probs, state_value
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ActorCriticAgent:
|
| 60 |
+
"""基础的Actor-Critic算法"""
|
| 61 |
+
|
| 62 |
+
def __init__(self,
|
| 63 |
+
env,
|
| 64 |
+
actor_lr=1e-3,
|
| 65 |
+
critic_lr=1e-3,
|
| 66 |
+
gamma=0.99,
|
| 67 |
+
hidden_dim=256,
|
| 68 |
+
render=False):
|
| 69 |
+
|
| 70 |
+
self.env = env
|
| 71 |
+
self.gamma = gamma
|
| 72 |
+
self.render = render
|
| 73 |
+
|
| 74 |
+
self.state_dim = env.observation_space.shape[0]
|
| 75 |
+
self.action_dim = env.action_space.n
|
| 76 |
+
|
| 77 |
+
# 使用共享网络或分离网络
|
| 78 |
+
self.use_shared_network = True
|
| 79 |
+
|
| 80 |
+
if self.use_shared_network:
|
| 81 |
+
# 共享网络版本
|
| 82 |
+
self.actor_critic = ActorCritic(self.state_dim, self.action_dim, hidden_dim)
|
| 83 |
+
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=actor_lr)
|
| 84 |
+
else:
|
| 85 |
+
# 分离网络版本
|
| 86 |
+
self.actor = nn.Sequential(
|
| 87 |
+
nn.Linear(self.state_dim, hidden_dim),
|
| 88 |
+
nn.ReLU(),
|
| 89 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 90 |
+
nn.ReLU(),
|
| 91 |
+
nn.Linear(hidden_dim, self.action_dim),
|
| 92 |
+
nn.Softmax(dim=-1)
|
| 93 |
+
)
|
| 94 |
+
self.critic = nn.Sequential(
|
| 95 |
+
nn.Linear(self.state_dim, hidden_dim),
|
| 96 |
+
nn.ReLU(),
|
| 97 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 98 |
+
nn.ReLU(),
|
| 99 |
+
nn.Linear(hidden_dim, 1)
|
| 100 |
+
)
|
| 101 |
+
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
|
| 102 |
+
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
|
| 103 |
+
|
| 104 |
+
# 训练统计
|
| 105 |
+
self.training_stats = {
|
| 106 |
+
'episode_rewards': [],
|
| 107 |
+
'critic_loss': [],
|
| 108 |
+
'actor_loss': [],
|
| 109 |
+
'advantages': []
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def select_action(self, state):
|
| 113 |
+
"""选择动作并返回动作、log概率和状态价值"""
|
| 114 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 115 |
+
|
| 116 |
+
if self.use_shared_network:
|
| 117 |
+
action_probs, state_value = self.actor_critic(state)
|
| 118 |
+
else:
|
| 119 |
+
action_probs = self.actor(state)
|
| 120 |
+
state_value = self.critic(state)
|
| 121 |
+
|
| 122 |
+
m = Categorical(action_probs)
|
| 123 |
+
action = m.sample()
|
| 124 |
+
log_prob = m.log_prob(action)
|
| 125 |
+
|
| 126 |
+
return action.item(), log_prob, state_value
|
| 127 |
+
|
| 128 |
+
def update(self, log_prob, state_value, reward, next_state_value, done):
|
| 129 |
+
"""单步更新Actor和Critic"""
|
| 130 |
+
if self.use_shared_network:
|
| 131 |
+
return self._update_shared(log_prob, state_value, reward, next_state_value, done)
|
| 132 |
+
else:
|
| 133 |
+
return self._update_separate(log_prob, state_value, reward, next_state_value, done)
|
| 134 |
+
|
| 135 |
+
def _update_shared(self, log_prob, state_value, reward, next_state_value, done):
|
| 136 |
+
"""共享网络更新"""
|
| 137 |
+
# 计算TD目标
|
| 138 |
+
td_target = reward + (1 - done) * self.gamma * next_state_value
|
| 139 |
+
td_target = td_target.detach()
|
| 140 |
+
|
| 141 |
+
# 计算TD误差(优势)
|
| 142 |
+
td_error = td_target - state_value
|
| 143 |
+
|
| 144 |
+
# Critic损失
|
| 145 |
+
critic_loss = td_error.pow(2).mean()
|
| 146 |
+
|
| 147 |
+
# Actor损失
|
| 148 |
+
actor_loss = -(log_prob * td_error.detach()).mean()
|
| 149 |
+
# td_error,实际的未来奖励 - 预期的未来奖励
|
| 150 |
+
# 如果是正数,则加大当前动作的概率,如果是负数则减小动作的概率。
|
| 151 |
+
|
| 152 |
+
# 总损失
|
| 153 |
+
total_loss = actor_loss + 0.5 * critic_loss
|
| 154 |
+
|
| 155 |
+
# 反向传播
|
| 156 |
+
self.optimizer.zero_grad()
|
| 157 |
+
total_loss.backward()
|
| 158 |
+
torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 0.5)
|
| 159 |
+
self.optimizer.step()
|
| 160 |
+
|
| 161 |
+
return actor_loss.item(), critic_loss.item(), td_error.mean().item()
|
| 162 |
+
|
| 163 |
+
def _update_separate(self, log_prob, state_value, reward, next_state_value, done):
|
| 164 |
+
"""分离网络更新"""
|
| 165 |
+
# 计算TD目标
|
| 166 |
+
td_target = reward + (1 - done) * self.gamma * next_state_value
|
| 167 |
+
td_target = td_target.detach()
|
| 168 |
+
|
| 169 |
+
# 计算TD误差
|
| 170 |
+
td_error = td_target - state_value
|
| 171 |
+
|
| 172 |
+
# 更新Critic
|
| 173 |
+
critic_loss = td_error.pow(2).mean()
|
| 174 |
+
self.critic_optimizer.zero_grad()
|
| 175 |
+
critic_loss.backward()
|
| 176 |
+
torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
|
| 177 |
+
self.critic_optimizer.step()
|
| 178 |
+
|
| 179 |
+
# 更新Actor
|
| 180 |
+
actor_loss = -(log_prob * td_error.detach()).mean()
|
| 181 |
+
self.actor_optimizer.zero_grad()
|
| 182 |
+
actor_loss.backward()
|
| 183 |
+
torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
|
| 184 |
+
self.actor_optimizer.step()
|
| 185 |
+
|
| 186 |
+
return actor_loss.item(), critic_loss.item(), td_error.mean().item()
|
| 187 |
+
|
| 188 |
+
def train(self, num_episodes=1000, max_steps=5000):
|
| 189 |
+
"""训练智能体"""
|
| 190 |
+
episode_rewards = []
|
| 191 |
+
episode_lengths = []
|
| 192 |
+
|
| 193 |
+
for episode in range(num_episodes):
|
| 194 |
+
state, _ = self.env.reset()
|
| 195 |
+
episode_reward = 0
|
| 196 |
+
episode_losses = {'actor': [], 'critic': [], 'advantage': []}
|
| 197 |
+
|
| 198 |
+
for step in range(max_steps):
|
| 199 |
+
if self.render:
|
| 200 |
+
self.env.render()
|
| 201 |
+
|
| 202 |
+
# 选择动作
|
| 203 |
+
action, log_prob, state_value = self.select_action(state)
|
| 204 |
+
|
| 205 |
+
# 执行动作
|
| 206 |
+
next_state, reward, terminated, truncated, _ = self.env.step(action)
|
| 207 |
+
done = terminated or truncated
|
| 208 |
+
|
| 209 |
+
# 获取下一状态的价值
|
| 210 |
+
with torch.no_grad():
|
| 211 |
+
next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
|
| 212 |
+
if self.use_shared_network:
|
| 213 |
+
_, next_state_value = self.actor_critic(next_state_tensor)
|
| 214 |
+
else:
|
| 215 |
+
next_state_value = self.critic(next_state_tensor)
|
| 216 |
+
|
| 217 |
+
# 更新网络
|
| 218 |
+
actor_loss, critic_loss, advantage = self.update(
|
| 219 |
+
log_prob, state_value, reward, next_state_value, done
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
episode_reward += reward
|
| 223 |
+
episode_losses['actor'].append(actor_loss)
|
| 224 |
+
episode_losses['critic'].append(critic_loss)
|
| 225 |
+
episode_losses['advantage'].append(advantage)
|
| 226 |
+
|
| 227 |
+
state = next_state
|
| 228 |
+
|
| 229 |
+
if done:
|
| 230 |
+
break
|
| 231 |
+
|
| 232 |
+
# 记录统计信息
|
| 233 |
+
episode_rewards.append(episode_reward)
|
| 234 |
+
episode_lengths.append(step + 1)
|
| 235 |
+
self.training_stats['episode_rewards'].append(episode_reward)
|
| 236 |
+
self.training_stats['actor_loss'].append(np.mean(episode_losses['actor']))
|
| 237 |
+
self.training_stats['critic_loss'].append(np.mean(episode_losses['critic']))
|
| 238 |
+
self.training_stats['advantages'].append(np.mean(episode_losses['advantage']))
|
| 239 |
+
|
| 240 |
+
# 打印进度
|
| 241 |
+
if (episode + 1) % 20 == 0:
|
| 242 |
+
avg_reward = np.mean(episode_rewards[-20:])
|
| 243 |
+
avg_actor_loss = np.mean(self.training_stats['actor_loss'][-20:])
|
| 244 |
+
avg_critic_loss = np.mean(self.training_stats['critic_loss'][-20:])
|
| 245 |
+
print(f"回合 {episode + 1:4d} | "
|
| 246 |
+
f"奖励: {episode_reward:5.1f} | "
|
| 247 |
+
f"平均奖励: {avg_reward:5.1f} | "
|
| 248 |
+
f"A-Loss: {avg_actor_loss:.4f} | "
|
| 249 |
+
f"C-Loss: {avg_critic_loss:.4f}")
|
| 250 |
+
|
| 251 |
+
return episode_rewards
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ============== 2. A2C (Advantage Actor-Critic) ==============
|
| 255 |
+
class A2C(nn.Module):
|
| 256 |
+
"""A2C网络 - 包含熵正则化"""
|
| 257 |
+
|
| 258 |
+
def __init__(self, state_dim, action_dim, hidden_dim=256):
|
| 259 |
+
super(A2C, self).__init__()
|
| 260 |
+
|
| 261 |
+
self.actor = nn.Sequential(
|
| 262 |
+
nn.Linear(state_dim, hidden_dim),
|
| 263 |
+
nn.ReLU(),
|
| 264 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 265 |
+
nn.ReLU(),
|
| 266 |
+
nn.Linear(hidden_dim, action_dim)
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
self.critic = nn.Sequential(
|
| 270 |
+
nn.Linear(state_dim, hidden_dim),
|
| 271 |
+
nn.ReLU(),
|
| 272 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 273 |
+
nn.ReLU(),
|
| 274 |
+
nn.Linear(hidden_dim, 1)
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def forward(self, state):
|
| 278 |
+
action_logits = self.actor(state)
|
| 279 |
+
state_value = self.critic(state)
|
| 280 |
+
return action_logits, state_value
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class A2CAgent:
|
| 284 |
+
"""A2C算法 - 使用优势函数和熵正则化"""
|
| 285 |
+
|
| 286 |
+
def __init__(self,
|
| 287 |
+
env,
|
| 288 |
+
learning_rate=3e-4,
|
| 289 |
+
gamma=0.99,
|
| 290 |
+
gae_lambda=0.95,
|
| 291 |
+
entropy_coef=0.01,
|
| 292 |
+
value_coef=0.5,
|
| 293 |
+
max_grad_norm=0.5,
|
| 294 |
+
hidden_dim=256):
|
| 295 |
+
|
| 296 |
+
self.env = env
|
| 297 |
+
self.gamma = gamma
|
| 298 |
+
self.gae_lambda = gae_lambda
|
| 299 |
+
self.entropy_coef = entropy_coef
|
| 300 |
+
self.value_coef = value_coef
|
| 301 |
+
self.max_grad_norm = max_grad_norm
|
| 302 |
+
|
| 303 |
+
self.state_dim = env.observation_space.shape[0]
|
| 304 |
+
self.action_dim = env.action_space.n
|
| 305 |
+
|
| 306 |
+
self.network = A2C(self.state_dim, self.action_dim, hidden_dim)
|
| 307 |
+
self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
|
| 308 |
+
|
| 309 |
+
# 存储经验
|
| 310 |
+
self.states = []
|
| 311 |
+
self.actions = []
|
| 312 |
+
self.log_probs = []
|
| 313 |
+
self.rewards = []
|
| 314 |
+
self.values = []
|
| 315 |
+
self.dones = []
|
| 316 |
+
|
| 317 |
+
self.training_stats = {
|
| 318 |
+
'episode_rewards': [],
|
| 319 |
+
'policy_loss': [],
|
| 320 |
+
'value_loss': [],
|
| 321 |
+
'entropy': []
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
def select_action(self, state):
|
| 325 |
+
"""选择动作并存储经验"""
|
| 326 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 327 |
+
action_logits, state_value = self.network(state)
|
| 328 |
+
|
| 329 |
+
action_probs = F.softmax(action_logits, dim=-1)
|
| 330 |
+
m = Categorical(action_probs)
|
| 331 |
+
action = m.sample()
|
| 332 |
+
log_prob = m.log_prob(action)
|
| 333 |
+
|
| 334 |
+
# 存储经验
|
| 335 |
+
self.states.append(state)
|
| 336 |
+
self.actions.append(action)
|
| 337 |
+
self.log_probs.append(log_prob)
|
| 338 |
+
self.values.append(state_value)
|
| 339 |
+
|
| 340 |
+
return action.item()
|
| 341 |
+
|
| 342 |
+
def compute_gae(self, rewards, values, dones):
|
| 343 |
+
"""计算广义优势估计(GAE)"""
|
| 344 |
+
advantages = []
|
| 345 |
+
gae = 0
|
| 346 |
+
|
| 347 |
+
values = values + [0] # 添加最后一个虚拟价值
|
| 348 |
+
|
| 349 |
+
for t in reversed(range(len(rewards))):
|
| 350 |
+
delta = rewards[t] + self.gamma * values[t + 1] * (1 - dones[t]) - values[t]
|
| 351 |
+
gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
|
| 352 |
+
advantages.insert(0, gae)
|
| 353 |
+
|
| 354 |
+
returns = [adv + val for adv, val in zip(advantages, values[:-1])]
|
| 355 |
+
|
| 356 |
+
return advantages, returns
|
| 357 |
+
|
| 358 |
+
def update(self):
|
| 359 |
+
"""更新网络参数"""
|
| 360 |
+
if len(self.rewards) == 0:
|
| 361 |
+
return
|
| 362 |
+
|
| 363 |
+
# 转换为tensor
|
| 364 |
+
states = torch.cat(self.states)
|
| 365 |
+
actions = torch.cat(self.actions)
|
| 366 |
+
old_log_probs = torch.cat(self.log_probs).detach()
|
| 367 |
+
rewards = self.rewards
|
| 368 |
+
values = [v.squeeze() for v in self.values]
|
| 369 |
+
dones = self.dones
|
| 370 |
+
|
| 371 |
+
# 计算GAE和returns
|
| 372 |
+
advantages, returns = self.compute_gae(rewards, values, dones)
|
| 373 |
+
advantages = torch.FloatTensor(advantages)
|
| 374 |
+
returns = torch.FloatTensor(returns)
|
| 375 |
+
|
| 376 |
+
# 标准化优势
|
| 377 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 378 |
+
|
| 379 |
+
# 重新计算log probs和values
|
| 380 |
+
action_logits, state_values = self.network(states)
|
| 381 |
+
action_probs = F.softmax(action_logits, dim=-1)
|
| 382 |
+
m = Categorical(action_probs)
|
| 383 |
+
log_probs = m.log_prob(actions)
|
| 384 |
+
entropy = m.entropy().mean()
|
| 385 |
+
|
| 386 |
+
# 计算损失
|
| 387 |
+
state_values = state_values.squeeze()
|
| 388 |
+
value_loss = F.mse_loss(state_values, returns)
|
| 389 |
+
|
| 390 |
+
policy_loss = -(log_probs * advantages.detach()).mean()
|
| 391 |
+
entropy_loss = -self.entropy_coef * entropy
|
| 392 |
+
|
| 393 |
+
total_loss = policy_loss + self.value_coef * value_loss + entropy_loss
|
| 394 |
+
|
| 395 |
+
# 反向传播
|
| 396 |
+
self.optimizer.zero_grad()
|
| 397 |
+
total_loss.backward()
|
| 398 |
+
torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
|
| 399 |
+
self.optimizer.step()
|
| 400 |
+
|
| 401 |
+
# 记录统计
|
| 402 |
+
policy_loss_val = policy_loss.item()
|
| 403 |
+
value_loss_val = value_loss.item()
|
| 404 |
+
entropy_val = entropy.item()
|
| 405 |
+
|
| 406 |
+
# 清空经验
|
| 407 |
+
self.states = []
|
| 408 |
+
self.actions = []
|
| 409 |
+
self.log_probs = []
|
| 410 |
+
self.rewards = []
|
| 411 |
+
self.values = []
|
| 412 |
+
self.dones = []
|
| 413 |
+
|
| 414 |
+
return policy_loss_val, value_loss_val, entropy_val
|
| 415 |
+
|
| 416 |
+
def train(self, num_episodes=1000, max_steps=500, update_frequency=10):
|
| 417 |
+
"""训练A2C智能体"""
|
| 418 |
+
episode_rewards = []
|
| 419 |
+
episode_lengths = []
|
| 420 |
+
|
| 421 |
+
for episode in range(num_episodes):
|
| 422 |
+
state, _ = self.env.reset()
|
| 423 |
+
episode_reward = 0
|
| 424 |
+
|
| 425 |
+
for step in range(max_steps):
|
| 426 |
+
# 选择动作
|
| 427 |
+
action = self.select_action(state)
|
| 428 |
+
|
| 429 |
+
# 执行动作
|
| 430 |
+
next_state, reward, terminated, truncated, _ = self.env.step(action)
|
| 431 |
+
done = terminated or truncated
|
| 432 |
+
|
| 433 |
+
# 存储经验
|
| 434 |
+
self.rewards.append(reward)
|
| 435 |
+
self.dones.append(done)
|
| 436 |
+
|
| 437 |
+
episode_reward += reward
|
| 438 |
+
state = next_state
|
| 439 |
+
|
| 440 |
+
if done:
|
| 441 |
+
break
|
| 442 |
+
|
| 443 |
+
episode_rewards.append(episode_reward)
|
| 444 |
+
episode_lengths.append(step + 1)
|
| 445 |
+
self.training_stats['episode_rewards'].append(episode_reward)
|
| 446 |
+
|
| 447 |
+
# 更新网络
|
| 448 |
+
if (episode + 1) % update_frequency == 0:
|
| 449 |
+
policy_loss, value_loss, entropy = self.update()
|
| 450 |
+
self.training_stats['policy_loss'].append(policy_loss)
|
| 451 |
+
self.training_stats['value_loss'].append(value_loss)
|
| 452 |
+
self.training_stats['entropy'].append(entropy)
|
| 453 |
+
|
| 454 |
+
# 打印进度
|
| 455 |
+
if (episode + 1) % 20 == 0:
|
| 456 |
+
avg_reward = np.mean(episode_rewards[-20:])
|
| 457 |
+
print(f"A2C - 回合 {episode + 1:4d} | "
|
| 458 |
+
f"奖励: {episode_reward:5.1f} | "
|
| 459 |
+
f"平均奖励: {avg_reward:5.1f}")
|
| 460 |
+
|
| 461 |
+
return episode_rewards
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# ============== 3. 可视化对比 ==============
|
| 465 |
+
def compare_algorithms():
|
| 466 |
+
"""对比基础Actor-Critic和A2C"""
|
| 467 |
+
print("\n" + "=" * 70)
|
| 468 |
+
print("Actor-Critic算法对比实验")
|
| 469 |
+
print("=" * 70)
|
| 470 |
+
|
| 471 |
+
# 创建环境
|
| 472 |
+
env = gym.make('CartPole-v1')
|
| 473 |
+
|
| 474 |
+
# 1. 基础Actor-Critic
|
| 475 |
+
print("\n1. 训练基础Actor-Critic...")
|
| 476 |
+
ac_agent = ActorCriticAgent(env, actor_lr=1e-3, critic_lr=1e-3)
|
| 477 |
+
ac_rewards = ac_agent.train(num_episodes=300)
|
| 478 |
+
|
| 479 |
+
# 2. A2C
|
| 480 |
+
print("\n2. 训练A2C (带GAE和熵正则化)...")
|
| 481 |
+
a2c_agent = A2CAgent(env, learning_rate=3e-4)
|
| 482 |
+
a2c_rewards = a2c_agent.train(num_episodes=300)
|
| 483 |
+
|
| 484 |
+
# 可视化对比
|
| 485 |
+
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
|
| 486 |
+
|
| 487 |
+
# 1. 奖励曲线对比
|
| 488 |
+
ax1 = axes[0, 0]
|
| 489 |
+
ax1.plot(ac_rewards, alpha=0.6, label='Actor-Critic', color='blue')
|
| 490 |
+
ax1.plot(a2c_rewards, alpha=0.6, label='A2C', color='red')
|
| 491 |
+
|
| 492 |
+
# 平滑曲线
|
| 493 |
+
window = 20
|
| 494 |
+
ac_smooth = np.convolve(ac_rewards, np.ones(window) / window, mode='valid')
|
| 495 |
+
a2c_smooth = np.convolve(a2c_rewards, np.ones(window) / window, mode='valid')
|
| 496 |
+
|
| 497 |
+
ax1.plot(range(window - 1, len(ac_smooth) + window - 1), ac_smooth,
|
| 498 |
+
'b-', linewidth=2, label='AC (平滑)')
|
| 499 |
+
ax1.plot(range(window - 1, len(a2c_smooth) + window - 1), a2c_smooth,
|
| 500 |
+
'r-', linewidth=2, label='A2C (平滑)')
|
| 501 |
+
|
| 502 |
+
ax1.set_xlabel('回合')
|
| 503 |
+
ax1.set_ylabel('总奖励')
|
| 504 |
+
ax1.set_title('训练奖励对比')
|
| 505 |
+
ax1.legend()
|
| 506 |
+
ax1.grid(True, alpha=0.3)
|
| 507 |
+
|
| 508 |
+
# 2. Actor损失对比
|
| 509 |
+
ax2 = axes[0, 1]
|
| 510 |
+
if hasattr(ac_agent.training_stats, 'actor_loss'):
|
| 511 |
+
ax2.plot(ac_agent.training_stats['actor_loss'],
|
| 512 |
+
label='Actor-Critic', color='blue', alpha=0.7)
|
| 513 |
+
ax2.plot(a2c_agent.training_stats['policy_loss'],
|
| 514 |
+
label='A2C', color='red', alpha=0.7)
|
| 515 |
+
ax2.set_xlabel('更新步')
|
| 516 |
+
ax2.set_ylabel('策略损失')
|
| 517 |
+
ax2.set_title('Actor损失对比')
|
| 518 |
+
ax2.legend()
|
| 519 |
+
ax2.grid(True, alpha=0.3)
|
| 520 |
+
|
| 521 |
+
# 3. Critic损失对比
|
| 522 |
+
ax3 = axes[0, 2]
|
| 523 |
+
if hasattr(ac_agent.training_stats, 'critic_loss'):
|
| 524 |
+
ax3.plot(ac_agent.training_stats['critic_loss'],
|
| 525 |
+
label='Actor-Critic', color='blue', alpha=0.7)
|
| 526 |
+
ax3.plot(a2c_agent.training_stats['value_loss'],
|
| 527 |
+
label='A2C', color='red', alpha=0.7)
|
| 528 |
+
ax3.set_xlabel('更新步')
|
| 529 |
+
ax3.set_ylabel('价值损失')
|
| 530 |
+
ax3.set_title('Critic损失对比')
|
| 531 |
+
ax3.legend()
|
| 532 |
+
ax3.grid(True, alpha=0.3)
|
| 533 |
+
|
| 534 |
+
# 4. 熵变化 (A2C特有)
|
| 535 |
+
ax4 = axes[1, 0]
|
| 536 |
+
if a2c_agent.training_stats['entropy']:
|
| 537 |
+
ax4.plot(a2c_agent.training_stats['entropy'],
|
| 538 |
+
color='green', linewidth=2)
|
| 539 |
+
ax4.set_xlabel('更新步')
|
| 540 |
+
ax4.set_ylabel('策略熵')
|
| 541 |
+
ax4.set_title('A2C - 策略熵变化')
|
| 542 |
+
ax4.grid(True, alpha=0.3)
|
| 543 |
+
|
| 544 |
+
# 5. 收敛速度箱线图
|
| 545 |
+
ax5 = axes[1, 1]
|
| 546 |
+
|
| 547 |
+
# 计算收敛回合数
|
| 548 |
+
def get_convergence_episode(rewards, threshold=450):
|
| 549 |
+
for i, r in enumerate(rewards):
|
| 550 |
+
if r >= threshold and np.mean(rewards[max(0, i - 10):i + 1]) >= threshold:
|
| 551 |
+
return i + 1
|
| 552 |
+
return len(rewards)
|
| 553 |
+
|
| 554 |
+
# 多次实验
|
| 555 |
+
n_trials = 10
|
| 556 |
+
ac_convergence = []
|
| 557 |
+
a2c_convergence = []
|
| 558 |
+
|
| 559 |
+
for trial in range(n_trials):
|
| 560 |
+
env_trial = gym.make('CartPole-v1')
|
| 561 |
+
|
| 562 |
+
ac_trial = ActorCriticAgent(env_trial, actor_lr=1e-3, critic_lr=1e-3)
|
| 563 |
+
ac_rewards_trial = ac_trial.train(num_episodes=200)
|
| 564 |
+
ac_convergence.append(get_convergence_episode(ac_rewards_trial))
|
| 565 |
+
|
| 566 |
+
a2c_trial = A2CAgent(env_trial, learning_rate=3e-4)
|
| 567 |
+
a2c_rewards_trial = a2c_trial.train(num_episodes=200)
|
| 568 |
+
a2c_convergence.append(get_convergence_episode(a2c_rewards_trial))
|
| 569 |
+
|
| 570 |
+
bp = ax5.boxplot([ac_convergence, a2c_convergence],
|
| 571 |
+
labels=['Actor-Critic', 'A2C'],
|
| 572 |
+
patch_artist=True)
|
| 573 |
+
bp['boxes'][0].set_facecolor('lightblue')
|
| 574 |
+
bp['boxes'][1].set_facecolor('lightcoral')
|
| 575 |
+
ax5.set_ylabel('收敛所需回合数')
|
| 576 |
+
ax5.set_title('收敛速度对比 (越低越好)')
|
| 577 |
+
ax5.grid(True, alpha=0.3)
|
| 578 |
+
|
| 579 |
+
# 6. 算法对比表格
|
| 580 |
+
ax6 = axes[1, 2]
|
| 581 |
+
ax6.axis('off')
|
| 582 |
+
|
| 583 |
+
# 创建对比表格
|
| 584 |
+
col_labels = ['算法', '更新方式', '优势估计', '熵正则', '收敛速度', '稳定性']
|
| 585 |
+
data = [
|
| 586 |
+
['Actor-Critic', '单步TD', 'TD误差', '无', '较慢', '中等'],
|
| 587 |
+
['A2C', '多步回报', 'GAE', '有', '快', '稳定']
|
| 588 |
+
]
|
| 589 |
+
|
| 590 |
+
table = ax6.table(cellText=data,
|
| 591 |
+
colLabels=col_labels,
|
| 592 |
+
cellLoc='center',
|
| 593 |
+
loc='center',
|
| 594 |
+
bbox=[0, 0, 1, 1])
|
| 595 |
+
table.auto_set_font_size(False)
|
| 596 |
+
table.set_fontsize(10)
|
| 597 |
+
table.scale(1, 1.5)
|
| 598 |
+
|
| 599 |
+
plt.suptitle('Actor-Critic vs A2C: 性能对比分析', fontsize=14, y=1.02)
|
| 600 |
+
plt.tight_layout()
|
| 601 |
+
plt.savefig('ac_vs_a2c_comparison.png', dpi=150, bbox_inches='tight')
|
| 602 |
+
plt.show()
|
| 603 |
+
|
| 604 |
+
# 打印总结
|
| 605 |
+
print("\n" + "=" * 70)
|
| 606 |
+
print("实验总结:")
|
| 607 |
+
print("=" * 70)
|
| 608 |
+
print(f"\n基础Actor-Critic:")
|
| 609 |
+
print(f" - 平均收敛回合: {np.mean(ac_convergence):.1f} ± {np.std(ac_convergence):.1f}")
|
| 610 |
+
print(f" - 最终平均奖励: {np.mean(ac_rewards[-50:]):.1f}")
|
| 611 |
+
|
| 612 |
+
print(f"\nA2C (带GAE和熵正则):")
|
| 613 |
+
print(f" - 平均收敛回合: {np.mean(a2c_convergence):.1f} ± {np.std(a2c_convergence):.1f}")
|
| 614 |
+
print(f" - 最终平均奖励: {np.mean(a2c_rewards[-50:]):.1f}")
|
| 615 |
+
|
| 616 |
+
improvement = (np.mean(ac_convergence) - np.mean(a2c_convergence)) / np.mean(ac_convergence) * 100
|
| 617 |
+
print(f"\nA2C收敛速度提升: {improvement:.1f}%")
|
| 618 |
+
|
| 619 |
+
return ac_agent, a2c_agent
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# ============== 4. 主函数 ==============
|
| 623 |
+
def main():
|
| 624 |
+
"""主函数"""
|
| 625 |
+
print("=" * 70)
|
| 626 |
+
print("Actor-Critic 方法完整实现")
|
| 627 |
+
print("=" * 70)
|
| 628 |
+
|
| 629 |
+
# 参数设置
|
| 630 |
+
import argparse
|
| 631 |
+
parser = argparse.ArgumentParser(description='Actor-Critic算法实现')
|
| 632 |
+
parser.add_argument('--algo', type=str, default='a2c',
|
| 633 |
+
choices=['ac', 'a2c', 'compare'],
|
| 634 |
+
help='选择算法: ac (基础Actor-Critic), a2c, compare')
|
| 635 |
+
parser.add_argument('--episodes', type=int, default=5000,
|
| 636 |
+
help='训练回合数')
|
| 637 |
+
parser.add_argument('--render', action='store_true',
|
| 638 |
+
help='渲染环境')
|
| 639 |
+
args = parser.parse_args()
|
| 640 |
+
|
| 641 |
+
env = gym.make('CartPole-v1')
|
| 642 |
+
|
| 643 |
+
if args.algo == 'ac':
|
| 644 |
+
print("\n训练基础Actor-Critic...")
|
| 645 |
+
agent = ActorCriticAgent(env, render=args.render)
|
| 646 |
+
rewards = agent.train(num_episodes=args.episodes)
|
| 647 |
+
|
| 648 |
+
plt.figure(figsize=(10, 6))
|
| 649 |
+
plt.plot(rewards, alpha=0.6, label='Episode Reward')
|
| 650 |
+
|
| 651 |
+
# 平滑曲线
|
| 652 |
+
window = 20
|
| 653 |
+
if len(rewards) >= window:
|
| 654 |
+
smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
|
| 655 |
+
plt.plot(range(window - 1, len(smoothed) + window - 1),
|
| 656 |
+
smoothed, 'r-', linewidth=2, label=f'Moving Avg (window={window})')
|
| 657 |
+
|
| 658 |
+
plt.xlabel('Episode')
|
| 659 |
+
plt.ylabel('Total Reward')
|
| 660 |
+
plt.title('Actor-Critic Training on CartPole')
|
| 661 |
+
plt.legend()
|
| 662 |
+
plt.grid(True, alpha=0.3)
|
| 663 |
+
plt.savefig('actor_critic_training.png', dpi=150)
|
| 664 |
+
plt.show()
|
| 665 |
+
|
| 666 |
+
elif args.algo == 'a2c':
|
| 667 |
+
print("\n训练A2C...")
|
| 668 |
+
agent = A2CAgent(env, learning_rate=3e-4)
|
| 669 |
+
rewards = agent.train(num_episodes=args.episodes)
|
| 670 |
+
|
| 671 |
+
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
| 672 |
+
|
| 673 |
+
# 奖励曲线
|
| 674 |
+
axes[0, 0].plot(rewards, alpha=0.6, color='blue')
|
| 675 |
+
if len(rewards) >= 20:
|
| 676 |
+
smoothed = np.convolve(rewards, np.ones(20) / 20, mode='valid')
|
| 677 |
+
axes[0, 0].plot(range(19, len(smoothed) + 19), smoothed, 'r-', linewidth=2)
|
| 678 |
+
axes[0, 0].set_xlabel('Episode')
|
| 679 |
+
axes[0, 0].set_ylabel('Total Reward')
|
| 680 |
+
axes[0, 0].set_title('A2C Training Rewards')
|
| 681 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 682 |
+
|
| 683 |
+
# 策略损失
|
| 684 |
+
axes[0, 1].plot(agent.training_stats['policy_loss'], color='purple')
|
| 685 |
+
axes[0, 1].set_xlabel('Update Step')
|
| 686 |
+
axes[0, 1].set_ylabel('Policy Loss')
|
| 687 |
+
axes[0, 1].set_title('Policy Loss')
|
| 688 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 689 |
+
|
| 690 |
+
# 价值损失
|
| 691 |
+
axes[1, 0].plot(agent.training_stats['value_loss'], color='orange')
|
| 692 |
+
axes[1, 0].set_xlabel('Update Step')
|
| 693 |
+
axes[1, 0].set_ylabel('Value Loss')
|
| 694 |
+
axes[1, 0].set_title('Value Loss')
|
| 695 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 696 |
+
|
| 697 |
+
# 策略熵
|
| 698 |
+
axes[1, 1].plot(agent.training_stats['entropy'], color='green')
|
| 699 |
+
axes[1, 1].set_xlabel('Update Step')
|
| 700 |
+
axes[1, 1].set_ylabel('Policy Entropy')
|
| 701 |
+
axes[1, 1].set_title('Policy Entropy')
|
| 702 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 703 |
+
|
| 704 |
+
plt.tight_layout()
|
| 705 |
+
plt.savefig('a2c_training.png', dpi=150)
|
| 706 |
+
plt.show()
|
| 707 |
+
|
| 708 |
+
else: # compare
|
| 709 |
+
compare_algorithms()
|
| 710 |
+
|
| 711 |
+
env.close()
|
| 712 |
+
print("\n训练完成!")
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
if __name__ == "__main__":
|
| 716 |
+
main()
|
examples/tutorials/rl/cart_pole/step_2_ppo_clip.py
ADDED
|
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
🌱 **第一代:REINFORCE (1992)**
|
| 5 |
+
├── 核心创新:蒙特卡洛策略梯度
|
| 6 |
+
├── 公式:∇θ ∝ Σ ∇θ log π * Gₜ
|
| 7 |
+
├── 痛点:▸ 必须等回合结束 ▸ 方差极大 ▸ 样本效率低
|
| 8 |
+
└── 贡献:开创了策略梯度范式
|
| 9 |
+
|
| 10 |
+
🌿 **第二代:REINFORCE with Baseline (2000s)**
|
| 11 |
+
├── 核心创新:引入基线降低方差
|
| 12 |
+
├── 公式:∇θ ∝ Σ ∇θ log π * (Gₜ - b(s))
|
| 13 |
+
├── 痛点:▸ 仍需完整回合 ▸ 基线需要单独学习
|
| 14 |
+
└── 贡献:方差降低,训练更稳定
|
| 15 |
+
|
| 16 |
+
🍃 **第三代:Actor-Critic (2000s)**
|
| 17 |
+
├── 核心创新:单步更新,不再等待回合
|
| 18 |
+
├── 公式:∇θ ∝ ∇θ log π * (r + γV(s') - V(s))
|
| 19 |
+
├── 痛点:▸ 单步TD偏差大 ▸ 价值估计不准
|
| 20 |
+
└── 贡献:实现了真正的在线学习
|
| 21 |
+
|
| 22 |
+
🌳 **第四代:A2C/A3C (2016)**
|
| 23 |
+
├── 核心创新:优势函数 + 多步回报
|
| 24 |
+
├── 公式:∇θ ∝ ∇θ log π * Â(s,a)
|
| 25 |
+
├── GAE:Â = Σ (γλ)ᵏ δ_{t+k} ★滑动平均★
|
| 26 |
+
├── 痛点:▸ 更新步长敏感 ▸ 容易破坏策略
|
| 27 |
+
└── 贡献:GAE成为标准配置
|
| 28 |
+
|
| 29 |
+
🌲 **第五代:PPO (2017)**
|
| 30 |
+
├── 核心创新:**所有前人智慧的集大成**
|
| 31 |
+
├── 1️⃣ 继承AC/A2C:单步更新 + GAE
|
| 32 |
+
├── 2️⃣ 继承重要性采样:可以复用数据
|
| 33 |
+
├── 3️⃣ ✨ **独创:Clipped Surrogate Objective**
|
| 34 |
+
│ L = min(r(θ)Â, clip(r(θ), 1-ε, 1+ε)Â)
|
| 35 |
+
├── 4️⃣ ✨ **独创:自适应KL惩罚**
|
| 36 |
+
└── 贡献:**稳定、高效、易用,成为事实标准**
|
| 37 |
+
|
| 38 |
+
由于函数的优化步长受到优势的直接影响。当优步长过大时很容易直接跨过,导致优化失败。
|
| 39 |
+
虽然A2C 已经对历史优势进行了移动平均,但问题仍然存在。
|
| 40 |
+
尤其是当训练的早期价值函数还没有获得较好的训练,这种问题尤其容易出现。
|
| 41 |
+
|
| 42 |
+
clip 当新模型相比于旧模型动作的概率变化幅度 radio 已经较大时,则不再进行优化。
|
| 43 |
+
clip 操作会对 radio 值进行截断,会切断梯度反向传播的通道。
|
| 44 |
+
|
| 45 |
+
同时对动作的概率做熵最大化,目的是避免策略函数过早确定化,保持探索能力。
|
| 46 |
+
|
| 47 |
+
PPO-clip 使用radio * advantages
|
| 48 |
+
其中:ratio = probs / old_probs
|
| 49 |
+
而不是 A2C 中的 log_probs * advantages
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
"""
|
| 53 |
+
import gymnasium as gym
|
| 54 |
+
import torch
|
| 55 |
+
import torch.nn as nn
|
| 56 |
+
import torch.optim as optim
|
| 57 |
+
from torch.distributions import Categorical
|
| 58 |
+
import numpy as np
|
| 59 |
+
import matplotlib.pyplot as plt
|
| 60 |
+
from collections import deque
|
| 61 |
+
import time
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ============== 1. PPO-Clip Network ==============
|
| 65 |
+
class PPONetwork(nn.Module):
|
| 66 |
+
"""PPO Network: Shared feature extractor + Actor head + Critic head"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, state_dim, action_dim, hidden_dim=64):
|
| 69 |
+
super(PPONetwork, self).__init__()
|
| 70 |
+
|
| 71 |
+
# Shared feature extractor
|
| 72 |
+
self.feature_layer = nn.Sequential(
|
| 73 |
+
nn.Linear(state_dim, hidden_dim),
|
| 74 |
+
nn.Tanh(),
|
| 75 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 76 |
+
nn.Tanh()
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Actor head: action probability distribution
|
| 80 |
+
self.actor = nn.Sequential(
|
| 81 |
+
nn.Linear(hidden_dim, action_dim),
|
| 82 |
+
nn.Softmax(dim=-1)
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Critic head: state value
|
| 86 |
+
self.critic = nn.Linear(hidden_dim, 1)
|
| 87 |
+
|
| 88 |
+
def forward(self, state):
|
| 89 |
+
features = self.feature_layer(state)
|
| 90 |
+
action_probs = self.actor(features)
|
| 91 |
+
state_value = self.critic(features)
|
| 92 |
+
return action_probs, state_value
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ============== 2. PPO-Clip Agent ==============
|
| 96 |
+
class PPOClipAgent:
|
| 97 |
+
"""PPO-Clip Agent with Gymnasium API"""
|
| 98 |
+
|
| 99 |
+
def __init__(self,
|
| 100 |
+
env,
|
| 101 |
+
learning_rate=3e-4,
|
| 102 |
+
gamma=0.99,
|
| 103 |
+
gae_lambda=0.95,
|
| 104 |
+
clip_epsilon=0.2,
|
| 105 |
+
entropy_coef=0.01,
|
| 106 |
+
value_coef=0.5,
|
| 107 |
+
max_grad_norm=0.5,
|
| 108 |
+
update_epochs=4,
|
| 109 |
+
mini_batch_size=64,
|
| 110 |
+
horizon=2048,
|
| 111 |
+
hidden_dim=64):
|
| 112 |
+
|
| 113 |
+
self.env = env
|
| 114 |
+
self.gamma = gamma
|
| 115 |
+
self.gae_lambda = gae_lambda
|
| 116 |
+
self.clip_epsilon = clip_epsilon
|
| 117 |
+
self.entropy_coef = entropy_coef
|
| 118 |
+
self.value_coef = value_coef
|
| 119 |
+
self.max_grad_norm = max_grad_norm
|
| 120 |
+
self.update_epochs = update_epochs
|
| 121 |
+
self.mini_batch_size = mini_batch_size
|
| 122 |
+
self.horizon = horizon
|
| 123 |
+
|
| 124 |
+
self.state_dim = env.observation_space.shape[0]
|
| 125 |
+
self.action_dim = env.action_space.n
|
| 126 |
+
|
| 127 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 128 |
+
|
| 129 |
+
self.network = PPONetwork(self.state_dim, self.action_dim, hidden_dim).to(self.device)
|
| 130 |
+
self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
|
| 131 |
+
|
| 132 |
+
self.reset_buffer()
|
| 133 |
+
|
| 134 |
+
self.training_stats = {
|
| 135 |
+
'episode_rewards': [],
|
| 136 |
+
'episode_lengths': [],
|
| 137 |
+
'policy_loss': [],
|
| 138 |
+
'value_loss': [],
|
| 139 |
+
'entropy': [],
|
| 140 |
+
'clip_fraction': [],
|
| 141 |
+
'explained_variance': []
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
# Statistics for table logging
|
| 145 |
+
self.recent_rewards = deque(maxlen=20)
|
| 146 |
+
self.recent_lengths = deque(maxlen=20)
|
| 147 |
+
|
| 148 |
+
def reset_buffer(self):
|
| 149 |
+
"""Reset experience buffer"""
|
| 150 |
+
self.buffer = {
|
| 151 |
+
'states': [],
|
| 152 |
+
'actions': [],
|
| 153 |
+
'rewards': [],
|
| 154 |
+
'next_states': [],
|
| 155 |
+
'dones': [],
|
| 156 |
+
'log_probs': [],
|
| 157 |
+
'values': []
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
def select_action(self, state, eval_mode=False):
|
| 161 |
+
"""Select action using current policy"""
|
| 162 |
+
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
| 163 |
+
|
| 164 |
+
with torch.no_grad():
|
| 165 |
+
action_probs, state_value = self.network(state)
|
| 166 |
+
|
| 167 |
+
m = Categorical(action_probs)
|
| 168 |
+
action = m.sample()
|
| 169 |
+
log_prob = m.log_prob(action)
|
| 170 |
+
|
| 171 |
+
if eval_mode:
|
| 172 |
+
action = torch.argmax(action_probs)
|
| 173 |
+
log_prob = m.log_prob(action)
|
| 174 |
+
|
| 175 |
+
return action.item(), log_prob.cpu().item(), state_value.cpu().item()
|
| 176 |
+
|
| 177 |
+
def store_transition(self, state, action, reward, next_state, done, log_prob, value):
|
| 178 |
+
"""Store one step of experience"""
|
| 179 |
+
self.buffer['states'].append(state)
|
| 180 |
+
self.buffer['actions'].append(action)
|
| 181 |
+
self.buffer['rewards'].append(reward)
|
| 182 |
+
self.buffer['next_states'].append(next_state)
|
| 183 |
+
self.buffer['dones'].append(done)
|
| 184 |
+
self.buffer['log_probs'].append(log_prob)
|
| 185 |
+
self.buffer['values'].append(value)
|
| 186 |
+
|
| 187 |
+
def compute_gae(self, rewards, values, next_values, dones):
|
| 188 |
+
"""Compute Generalized Advantage Estimation"""
|
| 189 |
+
advantages = []
|
| 190 |
+
gae = 0
|
| 191 |
+
|
| 192 |
+
for t in reversed(range(len(rewards))):
|
| 193 |
+
delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
|
| 194 |
+
gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
|
| 195 |
+
advantages.insert(0, gae)
|
| 196 |
+
|
| 197 |
+
return advantages
|
| 198 |
+
|
| 199 |
+
def update(self):
|
| 200 |
+
"""PPO-Clip update"""
|
| 201 |
+
if len(self.buffer['rewards']) < self.mini_batch_size:
|
| 202 |
+
return None
|
| 203 |
+
|
| 204 |
+
# Convert to tensors
|
| 205 |
+
states = torch.FloatTensor(np.array(self.buffer['states'])).to(self.device)
|
| 206 |
+
actions = torch.LongTensor(self.buffer['actions']).to(self.device)
|
| 207 |
+
old_log_probs = torch.FloatTensor(self.buffer['log_probs']).to(self.device)
|
| 208 |
+
|
| 209 |
+
values = torch.FloatTensor(self.buffer['values']).to(self.device)
|
| 210 |
+
next_states = torch.FloatTensor(np.array(self.buffer['next_states'])).to(self.device)
|
| 211 |
+
dones = torch.FloatTensor(self.buffer['dones']).to(self.device)
|
| 212 |
+
rewards = self.buffer['rewards']
|
| 213 |
+
|
| 214 |
+
# Compute next state values
|
| 215 |
+
with torch.no_grad():
|
| 216 |
+
_, next_values = self.network(next_states)
|
| 217 |
+
next_values = next_values.squeeze().cpu().numpy()
|
| 218 |
+
|
| 219 |
+
# Compute advantages and returns
|
| 220 |
+
advantages = self.compute_gae(
|
| 221 |
+
rewards,
|
| 222 |
+
values.cpu().numpy(),
|
| 223 |
+
next_values,
|
| 224 |
+
dones.cpu().numpy()
|
| 225 |
+
)
|
| 226 |
+
advantages = torch.FloatTensor(advantages).to(self.device)
|
| 227 |
+
returns = advantages + values
|
| 228 |
+
|
| 229 |
+
# Normalize advantages
|
| 230 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 231 |
+
|
| 232 |
+
total_policy_loss = 0
|
| 233 |
+
total_value_loss = 0
|
| 234 |
+
total_entropy = 0
|
| 235 |
+
total_clip_fraction = 0
|
| 236 |
+
|
| 237 |
+
dataset_size = len(states)
|
| 238 |
+
|
| 239 |
+
# Multiple epochs of PPO update
|
| 240 |
+
for epoch in range(self.update_epochs):
|
| 241 |
+
indices = np.random.permutation(dataset_size)
|
| 242 |
+
|
| 243 |
+
for start in range(0, dataset_size, self.mini_batch_size):
|
| 244 |
+
end = start + self.mini_batch_size
|
| 245 |
+
batch_indices = indices[start:end]
|
| 246 |
+
|
| 247 |
+
batch_states = states[batch_indices]
|
| 248 |
+
batch_actions = actions[batch_indices]
|
| 249 |
+
batch_old_log_probs = old_log_probs[batch_indices]
|
| 250 |
+
batch_advantages = advantages[batch_indices]
|
| 251 |
+
batch_returns = returns[batch_indices]
|
| 252 |
+
|
| 253 |
+
# Forward pass
|
| 254 |
+
action_probs, state_values = self.network(batch_states)
|
| 255 |
+
state_values = state_values.squeeze()
|
| 256 |
+
|
| 257 |
+
# Compute log probs and entropy
|
| 258 |
+
m = Categorical(action_probs)
|
| 259 |
+
log_probs = m.log_prob(batch_actions)
|
| 260 |
+
entropy = m.entropy().mean()
|
| 261 |
+
|
| 262 |
+
# Importance sampling ratio
|
| 263 |
+
ratio = torch.exp(log_probs - batch_old_log_probs)
|
| 264 |
+
|
| 265 |
+
# PPO-Clip objective
|
| 266 |
+
surr1 = ratio * batch_advantages
|
| 267 |
+
surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
|
| 268 |
+
policy_loss = -torch.min(surr1, surr2).mean()
|
| 269 |
+
|
| 270 |
+
# Value loss
|
| 271 |
+
value_loss = nn.MSELoss()(state_values, batch_returns)
|
| 272 |
+
|
| 273 |
+
# Entropy loss
|
| 274 |
+
entropy_loss = -self.entropy_coef * entropy
|
| 275 |
+
|
| 276 |
+
# Total loss
|
| 277 |
+
total_loss = policy_loss + self.value_coef * value_loss + entropy_loss
|
| 278 |
+
|
| 279 |
+
# Backward pass
|
| 280 |
+
self.optimizer.zero_grad()
|
| 281 |
+
total_loss.backward()
|
| 282 |
+
torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
|
| 283 |
+
self.optimizer.step()
|
| 284 |
+
|
| 285 |
+
total_policy_loss += policy_loss.item()
|
| 286 |
+
total_value_loss += value_loss.item()
|
| 287 |
+
total_entropy += entropy.item()
|
| 288 |
+
|
| 289 |
+
# Compute clip fraction
|
| 290 |
+
with torch.no_grad():
|
| 291 |
+
clip_mask = (ratio < 1 - self.clip_epsilon) | (ratio > 1 + self.clip_epsilon)
|
| 292 |
+
clip_fraction = clip_mask.float().mean().item()
|
| 293 |
+
total_clip_fraction += clip_fraction
|
| 294 |
+
|
| 295 |
+
n_updates = self.update_epochs * (dataset_size // self.mini_batch_size + 1)
|
| 296 |
+
|
| 297 |
+
# Compute explained variance
|
| 298 |
+
with torch.no_grad():
|
| 299 |
+
_, pred_values = self.network(states)
|
| 300 |
+
pred_values = pred_values.squeeze()
|
| 301 |
+
explained_variance = 1 - torch.var(returns - pred_values) / (torch.var(returns) + 1e-8)
|
| 302 |
+
explained_variance = explained_variance.cpu().item()
|
| 303 |
+
|
| 304 |
+
stats = {
|
| 305 |
+
'policy_loss': total_policy_loss / n_updates,
|
| 306 |
+
'value_loss': total_value_loss / n_updates,
|
| 307 |
+
'entropy': total_entropy / n_updates,
|
| 308 |
+
'clip_fraction': total_clip_fraction / n_updates,
|
| 309 |
+
'explained_variance': explained_variance
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
self.reset_buffer()
|
| 313 |
+
|
| 314 |
+
return stats
|
| 315 |
+
|
| 316 |
+
def print_header(self):
|
| 317 |
+
"""Print table header"""
|
| 318 |
+
print("\n" + "=" * 130)
|
| 319 |
+
print(
|
| 320 |
+
f"{'Episode':>8} | {'Avg Reward':>12} | {'Avg Length':>12} | {'Policy Loss':>12} | {'Value Loss':>12} | {'Entropy':>10} | {'Clip%':>10} | {'Expl Var':>12} | {'Time':>10}")
|
| 321 |
+
print("-" * 130)
|
| 322 |
+
|
| 323 |
+
def print_row(self, episode, total_episodes, stats=None, elapsed_time=None):
|
| 324 |
+
"""Print table row"""
|
| 325 |
+
|
| 326 |
+
# Update statistics cache
|
| 327 |
+
if len(self.training_stats['episode_rewards']) > 0:
|
| 328 |
+
self.recent_rewards.append(self.training_stats['episode_rewards'][-1])
|
| 329 |
+
if len(self.training_stats['episode_lengths']) > 0:
|
| 330 |
+
self.recent_lengths.append(self.training_stats['episode_lengths'][-1])
|
| 331 |
+
|
| 332 |
+
# Calculate averages
|
| 333 |
+
avg_reward = np.mean(self.recent_rewards) if self.recent_rewards else 0
|
| 334 |
+
avg_length = np.mean(self.recent_lengths) if self.recent_lengths else 0
|
| 335 |
+
|
| 336 |
+
# Format output
|
| 337 |
+
if stats:
|
| 338 |
+
print(f"{episode:>8}/{total_episodes} | "
|
| 339 |
+
f"{avg_reward:>12.2f} | "
|
| 340 |
+
f"{avg_length:>12.1f} | "
|
| 341 |
+
f"{stats['policy_loss']:>12.4f} | "
|
| 342 |
+
f"{stats['value_loss']:>12.4f} | "
|
| 343 |
+
f"{stats['entropy']:>10.4f} | "
|
| 344 |
+
f"{stats['clip_fraction']:>10.3f} | "
|
| 345 |
+
f"{stats['explained_variance']:>12.3f} | "
|
| 346 |
+
f"{elapsed_time:>10.1f}s")
|
| 347 |
+
else:
|
| 348 |
+
print(f"{episode:>8}/{total_episodes} | "
|
| 349 |
+
f"{avg_reward:>12.2f} | "
|
| 350 |
+
f"{avg_length:>12.1f} | "
|
| 351 |
+
f"{'-':>12} | "
|
| 352 |
+
f"{'-':>12} | "
|
| 353 |
+
f"{'-':>10} | "
|
| 354 |
+
f"{'-':>10} | "
|
| 355 |
+
f"{'-':>12} | "
|
| 356 |
+
f"{elapsed_time:>10.1f}s")
|
| 357 |
+
|
| 358 |
+
def print_summary(self, total_time, num_episodes):
|
| 359 |
+
"""Print training summary"""
|
| 360 |
+
print("=" * 130)
|
| 361 |
+
print(f"\n🎯 Training completed! Episodes: {num_episodes}, Total time: {total_time:.1f}s")
|
| 362 |
+
|
| 363 |
+
if len(self.training_stats['episode_rewards']) >= 20:
|
| 364 |
+
final_avg_reward = np.mean(self.training_stats['episode_rewards'][-20:])
|
| 365 |
+
final_avg_length = np.mean(self.training_stats['episode_lengths'][-20:])
|
| 366 |
+
print(f"📊 Last 20 episodes - Avg Reward: {final_avg_reward:.2f}, Avg Length: {final_avg_length:.1f}")
|
| 367 |
+
|
| 368 |
+
if self.training_stats['policy_loss']:
|
| 369 |
+
print(f"📉 Final Policy Loss: {self.training_stats['policy_loss'][-1]:.4f}")
|
| 370 |
+
print(f"📉 Final Value Loss: {self.training_stats['value_loss'][-1]:.4f}")
|
| 371 |
+
print(f"🎲 Final Entropy: {self.training_stats['entropy'][-1]:.4f}")
|
| 372 |
+
print(f"✂️ Final Clip Fraction: {self.training_stats['clip_fraction'][-1]:.3f}")
|
| 373 |
+
print(f"📈 Final Explained Variance: {self.training_stats['explained_variance'][-1]:.3f}")
|
| 374 |
+
|
| 375 |
+
print("=" * 130)
|
| 376 |
+
|
| 377 |
+
def train(self, num_episodes=1000, max_steps_per_episode=500, log_interval=20):
|
| 378 |
+
"""Train PPO-Clip agent"""
|
| 379 |
+
|
| 380 |
+
print("\n" + "🚀" * 65)
|
| 381 |
+
print("PPO-Clip Training Started (Gymnasium API)")
|
| 382 |
+
print("🚀" * 65)
|
| 383 |
+
|
| 384 |
+
print(f"\n📋 Hyperparameters:")
|
| 385 |
+
print(f" Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")
|
| 386 |
+
print(f" Gamma: {self.gamma:.2f}, GAE Lambda: {self.gae_lambda:.2f}")
|
| 387 |
+
print(f" Clip Epsilon: {self.clip_epsilon:.2f}, Update Epochs: {self.update_epochs}")
|
| 388 |
+
print(f" Mini-batch: {self.mini_batch_size}, Horizon: {self.horizon}")
|
| 389 |
+
print(f" Entropy Coef: {self.entropy_coef:.3f}, Value Coef: {self.value_coef:.1f}")
|
| 390 |
+
print(f" Device: {self.device}")
|
| 391 |
+
|
| 392 |
+
self.print_header()
|
| 393 |
+
|
| 394 |
+
total_steps = 0
|
| 395 |
+
episode = 0
|
| 396 |
+
start_time = time.time()
|
| 397 |
+
|
| 398 |
+
while episode < num_episodes:
|
| 399 |
+
# Gymnasium API: reset() returns state, info
|
| 400 |
+
state, _ = self.env.reset()
|
| 401 |
+
episode_reward = 0
|
| 402 |
+
episode_step = 0
|
| 403 |
+
|
| 404 |
+
while episode_step < max_steps_per_episode:
|
| 405 |
+
action, log_prob, value = self.select_action(state)
|
| 406 |
+
|
| 407 |
+
# Gymnasium API: step() returns next_state, reward, terminated, truncated, info
|
| 408 |
+
next_state, reward, terminated, truncated, _ = self.env.step(action)
|
| 409 |
+
done = terminated or truncated
|
| 410 |
+
|
| 411 |
+
self.store_transition(state, action, reward, next_state, done, log_prob, value)
|
| 412 |
+
|
| 413 |
+
episode_reward += reward
|
| 414 |
+
episode_step += 1
|
| 415 |
+
total_steps += 1
|
| 416 |
+
state = next_state
|
| 417 |
+
|
| 418 |
+
# Update when buffer is full or episode ends
|
| 419 |
+
if len(self.buffer['rewards']) >= self.horizon or done:
|
| 420 |
+
stats = self.update()
|
| 421 |
+
if stats:
|
| 422 |
+
self.training_stats['policy_loss'].append(stats['policy_loss'])
|
| 423 |
+
self.training_stats['value_loss'].append(stats['value_loss'])
|
| 424 |
+
self.training_stats['entropy'].append(stats['entropy'])
|
| 425 |
+
self.training_stats['clip_fraction'].append(stats['clip_fraction'])
|
| 426 |
+
self.training_stats['explained_variance'].append(stats['explained_variance'])
|
| 427 |
+
|
| 428 |
+
if done:
|
| 429 |
+
break
|
| 430 |
+
|
| 431 |
+
# Record episode statistics
|
| 432 |
+
self.training_stats['episode_rewards'].append(episode_reward)
|
| 433 |
+
self.training_stats['episode_lengths'].append(episode_step)
|
| 434 |
+
|
| 435 |
+
episode += 1
|
| 436 |
+
|
| 437 |
+
# Log periodically
|
| 438 |
+
if episode % log_interval == 0:
|
| 439 |
+
current_time = time.time()
|
| 440 |
+
elapsed = current_time - start_time
|
| 441 |
+
|
| 442 |
+
recent_stats = None
|
| 443 |
+
if self.training_stats['policy_loss']:
|
| 444 |
+
recent_stats = {
|
| 445 |
+
'policy_loss': self.training_stats['policy_loss'][-1],
|
| 446 |
+
'value_loss': self.training_stats['value_loss'][-1],
|
| 447 |
+
'entropy': self.training_stats['entropy'][-1],
|
| 448 |
+
'clip_fraction': self.training_stats['clip_fraction'][-1],
|
| 449 |
+
'explained_variance': self.training_stats['explained_variance'][-1]
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
self.print_row(episode, num_episodes, recent_stats, elapsed)
|
| 453 |
+
|
| 454 |
+
total_time = time.time() - start_time
|
| 455 |
+
self.print_summary(total_time, num_episodes)
|
| 456 |
+
|
| 457 |
+
return self.training_stats['episode_rewards'], self.training_stats['episode_lengths']
|
| 458 |
+
|
| 459 |
+
def save(self, path):
|
| 460 |
+
"""Save model"""
|
| 461 |
+
torch.save({
|
| 462 |
+
'network_state_dict': self.network.state_dict(),
|
| 463 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 464 |
+
'training_stats': self.training_stats
|
| 465 |
+
}, path)
|
| 466 |
+
print(f"\n💾 Model saved to {path}")
|
| 467 |
+
|
| 468 |
+
def load(self, path):
|
| 469 |
+
"""Load model"""
|
| 470 |
+
checkpoint = torch.load(path)
|
| 471 |
+
self.network.load_state_dict(checkpoint['network_state_dict'])
|
| 472 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 473 |
+
self.training_stats = checkpoint['training_stats']
|
| 474 |
+
print(f"\n📂 Model loaded from {path}")
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
# ============== 3. Evaluation Function ==============
|
| 478 |
+
def evaluate_agent(agent, env, num_episodes=10, render=False):
|
| 479 |
+
"""Evaluate trained agent"""
|
| 480 |
+
print("\n" + "🎯" * 35)
|
| 481 |
+
print("Evaluation Started")
|
| 482 |
+
print("🎯" * 35)
|
| 483 |
+
|
| 484 |
+
print(f"\n{'Episode':^12} | {'Reward':^12} | {'Length':^12} | {'Avg Reward':^14}")
|
| 485 |
+
print("-" * 60)
|
| 486 |
+
|
| 487 |
+
episode_rewards = []
|
| 488 |
+
episode_lengths = []
|
| 489 |
+
|
| 490 |
+
for episode in range(num_episodes):
|
| 491 |
+
state, _ = env.reset()
|
| 492 |
+
episode_reward = 0
|
| 493 |
+
episode_step = 0
|
| 494 |
+
|
| 495 |
+
while True:
|
| 496 |
+
if render:
|
| 497 |
+
env.render()
|
| 498 |
+
time.sleep(0.02)
|
| 499 |
+
|
| 500 |
+
action, _, _ = agent.select_action(state, eval_mode=True)
|
| 501 |
+
next_state, reward, terminated, truncated, _ = env.step(action)
|
| 502 |
+
done = terminated or truncated
|
| 503 |
+
|
| 504 |
+
episode_reward += reward
|
| 505 |
+
episode_step += 1
|
| 506 |
+
state = next_state
|
| 507 |
+
|
| 508 |
+
if done:
|
| 509 |
+
break
|
| 510 |
+
|
| 511 |
+
episode_rewards.append(episode_reward)
|
| 512 |
+
episode_lengths.append(episode_step)
|
| 513 |
+
|
| 514 |
+
avg_so_far = np.mean(episode_rewards)
|
| 515 |
+
print(f"{episode + 1:^12} | {episode_reward:^12.1f} | {episode_step:^12} | {avg_so_far:^14.2f}")
|
| 516 |
+
|
| 517 |
+
print("-" * 60)
|
| 518 |
+
print(f"\n📊 Evaluation Results ({num_episodes} episodes):")
|
| 519 |
+
print(f" Avg Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
|
| 520 |
+
print(f" Avg Length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")
|
| 521 |
+
print(f" Max Reward: {np.max(episode_rewards):.2f}")
|
| 522 |
+
print(f" Min Reward: {np.min(episode_rewards):.2f}")
|
| 523 |
+
print(f" Success Rate (>=475): {np.mean(np.array(episode_rewards) >= 475) * 100:.1f}%")
|
| 524 |
+
print("=" * 60)
|
| 525 |
+
|
| 526 |
+
return episode_rewards, episode_lengths
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
# ============== 4. Visualization Function (English Only) ==============
|
| 530 |
+
def plot_training_results(agent, save_path='ppo_training_results.png'):
|
| 531 |
+
"""Plot training results - English only, no Chinese font issues"""
|
| 532 |
+
|
| 533 |
+
stats = agent.training_stats
|
| 534 |
+
|
| 535 |
+
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
| 536 |
+
fig.suptitle('PPO-Clip Training Results (CartPole-v1)', fontsize=16, y=1.02)
|
| 537 |
+
|
| 538 |
+
window = 20
|
| 539 |
+
|
| 540 |
+
# 1. Episode Rewards
|
| 541 |
+
ax1 = axes[0, 0]
|
| 542 |
+
rewards = stats['episode_rewards']
|
| 543 |
+
ax1.plot(rewards, alpha=0.3, color='blue', label='Raw Reward')
|
| 544 |
+
|
| 545 |
+
if len(rewards) >= window:
|
| 546 |
+
smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
|
| 547 |
+
ax1.plot(range(window - 1, len(smoothed) + window - 1), smoothed,
|
| 548 |
+
'r-', linewidth=2, label=f'{window}-Episode MA')
|
| 549 |
+
|
| 550 |
+
ax1.axhline(y=500, color='green', linestyle='--', alpha=0.7, label='Target (500)')
|
| 551 |
+
ax1.set_xlabel('Episode')
|
| 552 |
+
ax1.set_ylabel('Total Reward')
|
| 553 |
+
ax1.set_title('Training Rewards')
|
| 554 |
+
ax1.legend()
|
| 555 |
+
ax1.grid(True, alpha=0.3)
|
| 556 |
+
|
| 557 |
+
# 2. Episode Lengths
|
| 558 |
+
ax2 = axes[0, 1]
|
| 559 |
+
lengths = stats['episode_lengths']
|
| 560 |
+
ax2.plot(lengths, alpha=0.3, color='orange', label='Episode Length')
|
| 561 |
+
|
| 562 |
+
if len(lengths) >= window:
|
| 563 |
+
smoothed_length = np.convolve(lengths, np.ones(window) / window, mode='valid')
|
| 564 |
+
ax2.plot(range(window - 1, len(smoothed_length) + window - 1), smoothed_length,
|
| 565 |
+
'r-', linewidth=2)
|
| 566 |
+
|
| 567 |
+
ax2.set_xlabel('Episode')
|
| 568 |
+
ax2.set_ylabel('Episode Length')
|
| 569 |
+
ax2.set_title('Episode Lengths')
|
| 570 |
+
ax2.grid(True, alpha=0.3)
|
| 571 |
+
|
| 572 |
+
# 3. Policy Loss
|
| 573 |
+
ax3 = axes[0, 2]
|
| 574 |
+
if stats['policy_loss']:
|
| 575 |
+
ax3.plot(stats['policy_loss'], color='purple', linewidth=1.5)
|
| 576 |
+
ax3.set_xlabel('Update Step')
|
| 577 |
+
ax3.set_ylabel('Policy Loss')
|
| 578 |
+
ax3.set_title('Policy Loss')
|
| 579 |
+
ax3.grid(True, alpha=0.3)
|
| 580 |
+
|
| 581 |
+
# 4. Value Loss
|
| 582 |
+
ax4 = axes[1, 0]
|
| 583 |
+
if stats['value_loss']:
|
| 584 |
+
ax4.plot(stats['value_loss'], color='brown', linewidth=1.5)
|
| 585 |
+
ax4.set_xlabel('Update Step')
|
| 586 |
+
ax4.set_ylabel('Value Loss')
|
| 587 |
+
ax4.set_title('Value Loss')
|
| 588 |
+
ax4.grid(True, alpha=0.3)
|
| 589 |
+
|
| 590 |
+
# 5. Policy Entropy
|
| 591 |
+
ax5 = axes[1, 1]
|
| 592 |
+
if stats['entropy']:
|
| 593 |
+
ax5.plot(stats['entropy'], color='green', linewidth=1.5)
|
| 594 |
+
ax5.set_xlabel('Update Step')
|
| 595 |
+
ax5.set_ylabel('Policy Entropy')
|
| 596 |
+
ax5.set_title('Policy Entropy (Exploration)')
|
| 597 |
+
ax5.grid(True, alpha=0.3)
|
| 598 |
+
|
| 599 |
+
# 6. Clip Fraction & Explained Variance
|
| 600 |
+
ax6 = axes[1, 2]
|
| 601 |
+
if stats['clip_fraction']:
|
| 602 |
+
ax6.plot(stats['clip_fraction'], color='red', linewidth=1.5, label='Clip Fraction')
|
| 603 |
+
ax6.set_xlabel('Update Step')
|
| 604 |
+
ax6.set_ylabel('Clip Fraction', color='red')
|
| 605 |
+
ax6.tick_params(axis='y', labelcolor='red')
|
| 606 |
+
ax6.grid(True, alpha=0.3)
|
| 607 |
+
|
| 608 |
+
ax6_twin = ax6.twinx()
|
| 609 |
+
if stats['explained_variance']:
|
| 610 |
+
ax6_twin.plot(stats['explained_variance'], color='blue', linewidth=1.5,
|
| 611 |
+
label='Explained Variance')
|
| 612 |
+
ax6_twin.set_ylabel('Explained Variance', color='blue')
|
| 613 |
+
ax6_twin.tick_params(axis='y', labelcolor='blue')
|
| 614 |
+
|
| 615 |
+
plt.tight_layout()
|
| 616 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 617 |
+
plt.show()
|
| 618 |
+
print(f"\n📸 Training results saved to {save_path}")
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
# ============== 5. Main Function ==============
|
| 622 |
+
def main():
|
| 623 |
+
"""Main function"""
|
| 624 |
+
|
| 625 |
+
# Create environment (Gymnasium)
|
| 626 |
+
env = gym.make('CartPole-v1')
|
| 627 |
+
|
| 628 |
+
# PPO-Clip hyperparameters
|
| 629 |
+
config = {
|
| 630 |
+
'learning_rate': 3e-4,
|
| 631 |
+
'gamma': 0.99,
|
| 632 |
+
'gae_lambda': 0.95,
|
| 633 |
+
'clip_epsilon': 0.2,
|
| 634 |
+
'entropy_coef': 0.01,
|
| 635 |
+
'value_coef': 0.5,
|
| 636 |
+
'max_grad_norm': 0.5,
|
| 637 |
+
'update_epochs': 4,
|
| 638 |
+
'mini_batch_size': 64,
|
| 639 |
+
'horizon': 2048,
|
| 640 |
+
'hidden_dim': 64
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
print("\n" + "=" * 90)
|
| 644 |
+
print("PPO-Clip for CartPole-v1 (Gymnasium)")
|
| 645 |
+
print("=" * 90)
|
| 646 |
+
print("\n📋 Hyperparameters:")
|
| 647 |
+
for key, value in config.items():
|
| 648 |
+
print(f" {key:20}: {value}")
|
| 649 |
+
|
| 650 |
+
# Create PPO agent
|
| 651 |
+
agent = PPOClipAgent(env, **config)
|
| 652 |
+
|
| 653 |
+
try:
|
| 654 |
+
# Train
|
| 655 |
+
rewards, lengths = agent.train(num_episodes=5000, log_interval=20)
|
| 656 |
+
|
| 657 |
+
# Save model
|
| 658 |
+
agent.save('ppo_cartpole_gymnasium.pth')
|
| 659 |
+
|
| 660 |
+
# Plot results
|
| 661 |
+
plot_training_results(agent)
|
| 662 |
+
|
| 663 |
+
# Evaluate
|
| 664 |
+
print("\n")
|
| 665 |
+
eval_env = gym.make('CartPole-v1')
|
| 666 |
+
evaluate_agent(agent, eval_env, num_episodes=20, render=False)
|
| 667 |
+
|
| 668 |
+
# Demo (with rendering)
|
| 669 |
+
print("\n")
|
| 670 |
+
demo_env = gym.make('CartPole-v1', render_mode='human')
|
| 671 |
+
evaluate_agent(agent, demo_env, num_episodes=3, render=True)
|
| 672 |
+
|
| 673 |
+
except KeyboardInterrupt:
|
| 674 |
+
print("\n\n⚠️ Training interrupted, saving model...")
|
| 675 |
+
agent.save('ppo_cartpole_gymnasium_interrupted.pth')
|
| 676 |
+
|
| 677 |
+
finally:
|
| 678 |
+
env.close()
|
| 679 |
+
if 'eval_env' in locals():
|
| 680 |
+
eval_env.close()
|
| 681 |
+
if 'demo_env' in locals():
|
| 682 |
+
demo_env.close()
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
# ============== 6. Hyperparameter Sweep ==============
|
| 686 |
+
def hyperparameter_sweep():
|
| 687 |
+
"""Hyperparameter tuning experiment"""
|
| 688 |
+
|
| 689 |
+
print("\n" + "🔬" * 45)
|
| 690 |
+
print("PPO-Clip Hyperparameter Sweep")
|
| 691 |
+
print("🔬" * 45)
|
| 692 |
+
|
| 693 |
+
clip_values = [0.1, 0.2, 0.3]
|
| 694 |
+
results = {}
|
| 695 |
+
|
| 696 |
+
for clip_eps in clip_values:
|
| 697 |
+
print(f"\n📊 Testing clip_epsilon = {clip_eps}")
|
| 698 |
+
print("-" * 60)
|
| 699 |
+
|
| 700 |
+
env = gym.make('CartPole-v1')
|
| 701 |
+
agent = PPOClipAgent(
|
| 702 |
+
env,
|
| 703 |
+
learning_rate=3e-4,
|
| 704 |
+
clip_epsilon=clip_eps,
|
| 705 |
+
update_epochs=4,
|
| 706 |
+
horizon=2048,
|
| 707 |
+
mini_batch_size=64
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
rewards, _ = agent.train(num_episodes=2000, log_interval=20)
|
| 711 |
+
results[f'ε={clip_eps}'] = rewards
|
| 712 |
+
env.close()
|
| 713 |
+
|
| 714 |
+
# Plot comparison
|
| 715 |
+
plt.figure(figsize=(12, 6))
|
| 716 |
+
|
| 717 |
+
for name, rewards in results.items():
|
| 718 |
+
window = 20
|
| 719 |
+
smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
|
| 720 |
+
plt.plot(range(window - 1, len(smoothed) + window - 1), smoothed,
|
| 721 |
+
linewidth=2, label=name)
|
| 722 |
+
|
| 723 |
+
plt.xlabel('Episode')
|
| 724 |
+
plt.ylabel(f'Avg Reward ({window}-Episode MA)')
|
| 725 |
+
plt.title('PPO-Clip: Different Clip Epsilon Comparison')
|
| 726 |
+
plt.legend()
|
| 727 |
+
plt.grid(True, alpha=0.3)
|
| 728 |
+
plt.savefig('ppo_clip_comparison.png', dpi=150)
|
| 729 |
+
plt.show()
|
| 730 |
+
|
| 731 |
+
return results
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
if __name__ == "__main__":
|
| 735 |
+
main()
|
| 736 |
+
# hyperparameter_sweep() # Uncomment to run hyperparameter sweep
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
|
examples/tutorials/rl/cart_pole/step_2_ppo_penalty.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
KL散度损失就是为了直接约束新旧策略之间的变化程度。
|
| 5 |
+
|
| 6 |
+
使用 KL散度的好处:
|
| 7 |
+
penalty 更像是,不管优势多大,它总能将其进行可控的相对缩放。
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
clip
|
| 11 |
+
当新模型相比旧模型的变化幅度已较大时,clip以阻断优化,会切断梯度传导。
|
| 12 |
+
penalty
|
| 13 |
+
直接将新模型与旧模型的动作概率约束在一个 target_kl 附近,限制每一次迭代的优化幅度。
|
| 14 |
+
KL散度不会切断梯度传导,总是可以进行有效的优化。
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
import gymnasium as gym
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
from torch.distributions import Categorical
|
| 22 |
+
import numpy as np
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
from collections import deque
|
| 25 |
+
import time
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ============== 1. PPO-Penalty Network ==============
|
| 29 |
+
class PPONetwork(nn.Module):
|
| 30 |
+
"""PPO Network: Shared feature extractor + Actor head + Critic head"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, state_dim, action_dim, hidden_dim=64):
|
| 33 |
+
super(PPONetwork, self).__init__()
|
| 34 |
+
|
| 35 |
+
# Shared feature extractor
|
| 36 |
+
self.feature_layer = nn.Sequential(
|
| 37 |
+
nn.Linear(state_dim, hidden_dim),
|
| 38 |
+
nn.Tanh(),
|
| 39 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 40 |
+
nn.Tanh()
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Actor head: action probability distribution
|
| 44 |
+
self.actor = nn.Sequential(
|
| 45 |
+
nn.Linear(hidden_dim, action_dim),
|
| 46 |
+
nn.Softmax(dim=-1)
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Critic head: state value
|
| 50 |
+
self.critic = nn.Linear(hidden_dim, 1)
|
| 51 |
+
|
| 52 |
+
def forward(self, state):
|
| 53 |
+
features = self.feature_layer(state)
|
| 54 |
+
action_probs = self.actor(features)
|
| 55 |
+
state_value = self.critic(features)
|
| 56 |
+
return action_probs, state_value
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ============== 2. PPO-Penalty Agent ==============
|
| 60 |
+
class PPOPenaltyAgent:
|
| 61 |
+
"""PPO-Penalty (Adaptive KL Penalty) with Gymnasium API"""
|
| 62 |
+
|
| 63 |
+
def __init__(self,
|
| 64 |
+
env,
|
| 65 |
+
learning_rate=3e-4,
|
| 66 |
+
gamma=0.99,
|
| 67 |
+
gae_lambda=0.95,
|
| 68 |
+
kl_target=0.01, # Target KL divergence
|
| 69 |
+
kl_coef_init=1.0, # Initial KL penalty coefficient
|
| 70 |
+
kl_coef_adapt=1.5, # KL coefficient adaptation rate
|
| 71 |
+
entropy_coef=0.01,
|
| 72 |
+
value_coef=0.5,
|
| 73 |
+
max_grad_norm=0.5,
|
| 74 |
+
update_epochs=10, # PPO-Penalty typically uses more epochs
|
| 75 |
+
mini_batch_size=64,
|
| 76 |
+
horizon=2048,
|
| 77 |
+
hidden_dim=64):
|
| 78 |
+
|
| 79 |
+
self.env = env
|
| 80 |
+
self.gamma = gamma
|
| 81 |
+
self.gae_lambda = gae_lambda
|
| 82 |
+
self.kl_target = kl_target
|
| 83 |
+
self.kl_coef = kl_coef_init
|
| 84 |
+
self.kl_coef_adapt = kl_coef_adapt
|
| 85 |
+
self.entropy_coef = entropy_coef
|
| 86 |
+
self.value_coef = value_coef
|
| 87 |
+
self.max_grad_norm = max_grad_norm
|
| 88 |
+
self.update_epochs = update_epochs
|
| 89 |
+
self.mini_batch_size = mini_batch_size
|
| 90 |
+
self.horizon = horizon
|
| 91 |
+
|
| 92 |
+
self.state_dim = env.observation_space.shape[0]
|
| 93 |
+
self.action_dim = env.action_space.n
|
| 94 |
+
|
| 95 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 96 |
+
|
| 97 |
+
# Policy network
|
| 98 |
+
self.policy = PPONetwork(self.state_dim, self.action_dim, hidden_dim).to(self.device)
|
| 99 |
+
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
|
| 100 |
+
|
| 101 |
+
# Old policy for KL calculation
|
| 102 |
+
self.old_policy = PPONetwork(self.state_dim, self.action_dim, hidden_dim).to(self.device)
|
| 103 |
+
self.update_old_policy()
|
| 104 |
+
|
| 105 |
+
self.reset_buffer()
|
| 106 |
+
|
| 107 |
+
self.training_stats = {
|
| 108 |
+
'episode_rewards': [],
|
| 109 |
+
'episode_lengths': [],
|
| 110 |
+
'policy_loss': [],
|
| 111 |
+
'value_loss': [],
|
| 112 |
+
'entropy': [],
|
| 113 |
+
'kl_divergence': [],
|
| 114 |
+
'kl_coef': [],
|
| 115 |
+
'explained_variance': []
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
# Statistics for table logging
|
| 119 |
+
self.recent_rewards = deque(maxlen=20)
|
| 120 |
+
self.recent_lengths = deque(maxlen=20)
|
| 121 |
+
|
| 122 |
+
def update_old_policy(self):
|
| 123 |
+
"""Copy current policy to old policy"""
|
| 124 |
+
self.old_policy.load_state_dict(self.policy.state_dict())
|
| 125 |
+
|
| 126 |
+
def reset_buffer(self):
|
| 127 |
+
"""Reset experience buffer"""
|
| 128 |
+
self.buffer = {
|
| 129 |
+
'states': [],
|
| 130 |
+
'actions': [],
|
| 131 |
+
'rewards': [],
|
| 132 |
+
'next_states': [],
|
| 133 |
+
'dones': [],
|
| 134 |
+
'log_probs': [],
|
| 135 |
+
'values': []
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
def select_action(self, state, eval_mode=False):
|
| 139 |
+
"""Select action using current policy"""
|
| 140 |
+
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
| 141 |
+
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
action_probs, state_value = self.policy(state)
|
| 144 |
+
|
| 145 |
+
m = Categorical(action_probs)
|
| 146 |
+
action = m.sample()
|
| 147 |
+
log_prob = m.log_prob(action)
|
| 148 |
+
|
| 149 |
+
if eval_mode:
|
| 150 |
+
action = torch.argmax(action_probs)
|
| 151 |
+
log_prob = m.log_prob(action)
|
| 152 |
+
|
| 153 |
+
return action.item(), log_prob.cpu().item(), state_value.cpu().item()
|
| 154 |
+
|
| 155 |
+
def store_transition(self, state, action, reward, next_state, done, log_prob, value):
|
| 156 |
+
"""Store one step of experience"""
|
| 157 |
+
self.buffer['states'].append(state)
|
| 158 |
+
self.buffer['actions'].append(action)
|
| 159 |
+
self.buffer['rewards'].append(reward)
|
| 160 |
+
self.buffer['next_states'].append(next_state)
|
| 161 |
+
self.buffer['dones'].append(done)
|
| 162 |
+
self.buffer['log_probs'].append(log_prob)
|
| 163 |
+
self.buffer['values'].append(value)
|
| 164 |
+
|
| 165 |
+
def compute_gae(self, rewards, values, next_values, dones):
|
| 166 |
+
"""Compute Generalized Advantage Estimation"""
|
| 167 |
+
advantages = []
|
| 168 |
+
gae = 0
|
| 169 |
+
|
| 170 |
+
for t in reversed(range(len(rewards))):
|
| 171 |
+
delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
|
| 172 |
+
gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
|
| 173 |
+
advantages.insert(0, gae)
|
| 174 |
+
|
| 175 |
+
return advantages
|
| 176 |
+
|
| 177 |
+
def compute_kl_divergence(self, states, actions):
|
| 178 |
+
"""Compute KL divergence between old and new policy"""
|
| 179 |
+
with torch.no_grad():
|
| 180 |
+
# Get old policy distributions
|
| 181 |
+
old_probs, _ = self.old_policy(states)
|
| 182 |
+
old_m = Categorical(old_probs)
|
| 183 |
+
|
| 184 |
+
# Get new policy distributions
|
| 185 |
+
new_probs, _ = self.policy(states)
|
| 186 |
+
new_m = Categorical(new_probs)
|
| 187 |
+
|
| 188 |
+
# Compute KL divergence
|
| 189 |
+
kl = torch.distributions.kl.kl_divergence(old_m, new_m).mean()
|
| 190 |
+
|
| 191 |
+
return kl.item()
|
| 192 |
+
|
| 193 |
+
def update(self):
|
| 194 |
+
"""PPO-Penalty update with adaptive KL penalty"""
|
| 195 |
+
if len(self.buffer['rewards']) < self.mini_batch_size:
|
| 196 |
+
return None
|
| 197 |
+
|
| 198 |
+
# Convert to tensors
|
| 199 |
+
states = torch.FloatTensor(np.array(self.buffer['states'])).to(self.device)
|
| 200 |
+
actions = torch.LongTensor(self.buffer['actions']).to(self.device)
|
| 201 |
+
old_log_probs = torch.FloatTensor(self.buffer['log_probs']).to(self.device)
|
| 202 |
+
|
| 203 |
+
values = torch.FloatTensor(self.buffer['values']).to(self.device)
|
| 204 |
+
next_states = torch.FloatTensor(np.array(self.buffer['next_states'])).to(self.device)
|
| 205 |
+
dones = torch.FloatTensor(self.buffer['dones']).to(self.device)
|
| 206 |
+
rewards = self.buffer['rewards']
|
| 207 |
+
|
| 208 |
+
# Compute next state values
|
| 209 |
+
with torch.no_grad():
|
| 210 |
+
_, next_values = self.policy(next_states)
|
| 211 |
+
next_values = next_values.squeeze().cpu().numpy()
|
| 212 |
+
|
| 213 |
+
# Compute advantages and returns
|
| 214 |
+
advantages = self.compute_gae(
|
| 215 |
+
rewards,
|
| 216 |
+
values.cpu().numpy(),
|
| 217 |
+
next_values,
|
| 218 |
+
dones.cpu().numpy()
|
| 219 |
+
)
|
| 220 |
+
advantages = torch.FloatTensor(advantages).to(self.device)
|
| 221 |
+
returns = advantages + values
|
| 222 |
+
|
| 223 |
+
# Normalize advantages
|
| 224 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 225 |
+
|
| 226 |
+
total_policy_loss = 0
|
| 227 |
+
total_value_loss = 0
|
| 228 |
+
total_entropy = 0
|
| 229 |
+
total_kl = 0
|
| 230 |
+
|
| 231 |
+
dataset_size = len(states)
|
| 232 |
+
|
| 233 |
+
# Multiple epochs of PPO update
|
| 234 |
+
for epoch in range(self.update_epochs):
|
| 235 |
+
indices = np.random.permutation(dataset_size)
|
| 236 |
+
|
| 237 |
+
for start in range(0, dataset_size, self.mini_batch_size):
|
| 238 |
+
end = start + self.mini_batch_size
|
| 239 |
+
batch_indices = indices[start:end]
|
| 240 |
+
|
| 241 |
+
batch_states = states[batch_indices]
|
| 242 |
+
batch_actions = actions[batch_indices]
|
| 243 |
+
batch_old_log_probs = old_log_probs[batch_indices]
|
| 244 |
+
batch_advantages = advantages[batch_indices]
|
| 245 |
+
batch_returns = returns[batch_indices]
|
| 246 |
+
|
| 247 |
+
# Forward pass
|
| 248 |
+
action_probs, state_values = self.policy(batch_states)
|
| 249 |
+
state_values = state_values.squeeze()
|
| 250 |
+
|
| 251 |
+
# Compute log probs and entropy
|
| 252 |
+
m = Categorical(action_probs)
|
| 253 |
+
log_probs = m.log_prob(batch_actions)
|
| 254 |
+
entropy = m.entropy().mean()
|
| 255 |
+
|
| 256 |
+
# Importance sampling ratio
|
| 257 |
+
ratio = torch.exp(log_probs - batch_old_log_probs)
|
| 258 |
+
|
| 259 |
+
# PPO-Penalty objective (with KL penalty, no clipping!)
|
| 260 |
+
policy_loss = -(ratio * batch_advantages).mean()
|
| 261 |
+
|
| 262 |
+
# Compute KL divergence for this batch
|
| 263 |
+
with torch.no_grad():
|
| 264 |
+
old_probs, _ = self.old_policy(batch_states)
|
| 265 |
+
old_m = Categorical(old_probs)
|
| 266 |
+
kl_batch = torch.distributions.kl.kl_divergence(old_m, m).mean()
|
| 267 |
+
total_kl += kl_batch.item()
|
| 268 |
+
|
| 269 |
+
# Add KL penalty to policy loss
|
| 270 |
+
policy_loss_penalized = policy_loss + self.kl_coef * kl_batch
|
| 271 |
+
|
| 272 |
+
# Value loss
|
| 273 |
+
value_loss = nn.MSELoss()(state_values, batch_returns)
|
| 274 |
+
|
| 275 |
+
# Entropy loss (encourage exploration)
|
| 276 |
+
entropy_loss = -self.entropy_coef * entropy
|
| 277 |
+
|
| 278 |
+
# Total loss
|
| 279 |
+
total_loss = policy_loss_penalized + self.value_coef * value_loss + entropy_loss
|
| 280 |
+
|
| 281 |
+
# Backward pass
|
| 282 |
+
self.optimizer.zero_grad()
|
| 283 |
+
total_loss.backward()
|
| 284 |
+
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
| 285 |
+
self.optimizer.step()
|
| 286 |
+
|
| 287 |
+
total_policy_loss += policy_loss.item()
|
| 288 |
+
total_value_loss += value_loss.item()
|
| 289 |
+
total_entropy += entropy.item()
|
| 290 |
+
|
| 291 |
+
# Compute average KL divergence
|
| 292 |
+
n_updates = self.update_epochs * (dataset_size // self.mini_batch_size + 1)
|
| 293 |
+
avg_kl = total_kl / n_updates
|
| 294 |
+
|
| 295 |
+
# Adapt KL coefficient (核心:自适应调整KL惩罚系数)
|
| 296 |
+
if avg_kl < self.kl_target / 1.5:
|
| 297 |
+
# KL too small -> reduce penalty
|
| 298 |
+
self.kl_coef /= self.kl_coef_adapt
|
| 299 |
+
elif avg_kl > self.kl_target * 1.5:
|
| 300 |
+
# KL too large -> increase penalty
|
| 301 |
+
self.kl_coef *= self.kl_coef_adapt
|
| 302 |
+
|
| 303 |
+
# Keep KL coefficient in reasonable range
|
| 304 |
+
self.kl_coef = np.clip(self.kl_coef, 1e-10, 10.0)
|
| 305 |
+
|
| 306 |
+
# Compute explained variance
|
| 307 |
+
with torch.no_grad():
|
| 308 |
+
_, pred_values = self.policy(states)
|
| 309 |
+
pred_values = pred_values.squeeze()
|
| 310 |
+
explained_variance = 1 - torch.var(returns - pred_values) / (torch.var(returns) + 1e-8)
|
| 311 |
+
explained_variance = explained_variance.cpu().item()
|
| 312 |
+
|
| 313 |
+
stats = {
|
| 314 |
+
'policy_loss': total_policy_loss / n_updates,
|
| 315 |
+
'value_loss': total_value_loss / n_updates,
|
| 316 |
+
'entropy': total_entropy / n_updates,
|
| 317 |
+
'kl_divergence': avg_kl,
|
| 318 |
+
'kl_coef': self.kl_coef,
|
| 319 |
+
'explained_variance': explained_variance
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
# Update old policy
|
| 323 |
+
self.update_old_policy()
|
| 324 |
+
self.reset_buffer()
|
| 325 |
+
|
| 326 |
+
return stats
|
| 327 |
+
|
| 328 |
+
def print_header(self):
|
| 329 |
+
"""Print table header"""
|
| 330 |
+
print("\n" + "=" * 150)
|
| 331 |
+
print(
|
| 332 |
+
f"{'Episode':>8} | {'Avg Reward':>12} | {'Avg Length':>12} | {'Policy Loss':>12} | {'Value Loss':>12} | {'Entropy':>10} | {'KL Div':>10} | {'KL Coef':>10} | {'Expl Var':>12} | {'Time':>10}")
|
| 333 |
+
print("-" * 150)
|
| 334 |
+
|
| 335 |
+
def print_row(self, episode, total_episodes, stats=None, elapsed_time=None):
|
| 336 |
+
"""Print table row"""
|
| 337 |
+
|
| 338 |
+
# Update statistics cache
|
| 339 |
+
if len(self.training_stats['episode_rewards']) > 0:
|
| 340 |
+
self.recent_rewards.append(self.training_stats['episode_rewards'][-1])
|
| 341 |
+
if len(self.training_stats['episode_lengths']) > 0:
|
| 342 |
+
self.recent_lengths.append(self.training_stats['episode_lengths'][-1])
|
| 343 |
+
|
| 344 |
+
# Calculate averages
|
| 345 |
+
avg_reward = np.mean(self.recent_rewards) if self.recent_rewards else 0
|
| 346 |
+
avg_length = np.mean(self.recent_lengths) if self.recent_lengths else 0
|
| 347 |
+
|
| 348 |
+
# Format output
|
| 349 |
+
if stats:
|
| 350 |
+
print(f"{episode:>8}/{total_episodes} | "
|
| 351 |
+
f"{avg_reward:>12.2f} | "
|
| 352 |
+
f"{avg_length:>12.1f} | "
|
| 353 |
+
f"{stats['policy_loss']:>12.4f} | "
|
| 354 |
+
f"{stats['value_loss']:>12.4f} | "
|
| 355 |
+
f"{stats['entropy']:>10.4f} | "
|
| 356 |
+
f"{stats['kl_divergence']:>10.6f} | "
|
| 357 |
+
f"{stats['kl_coef']:>10.6f} | "
|
| 358 |
+
f"{stats['explained_variance']:>12.3f} | "
|
| 359 |
+
f"{elapsed_time:>10.1f}s")
|
| 360 |
+
else:
|
| 361 |
+
print(f"{episode:>8}/{total_episodes} | "
|
| 362 |
+
f"{avg_reward:>12.2f} | "
|
| 363 |
+
f"{avg_length:>12.1f} | "
|
| 364 |
+
f"{'-':>12} | "
|
| 365 |
+
f"{'-':>12} | "
|
| 366 |
+
f"{'-':>10} | "
|
| 367 |
+
f"{'-':>10} | "
|
| 368 |
+
f"{'-':>10} | "
|
| 369 |
+
f"{'-':>12} | "
|
| 370 |
+
f"{elapsed_time:>10.1f}s")
|
| 371 |
+
|
| 372 |
+
def print_summary(self, total_time, num_episodes):
|
| 373 |
+
"""Print training summary"""
|
| 374 |
+
print("=" * 150)
|
| 375 |
+
print(f"\n🎯 Training completed! Episodes: {num_episodes}, Total time: {total_time:.1f}s")
|
| 376 |
+
|
| 377 |
+
if len(self.training_stats['episode_rewards']) >= 20:
|
| 378 |
+
final_avg_reward = np.mean(self.training_stats['episode_rewards'][-20:])
|
| 379 |
+
final_avg_length = np.mean(self.training_stats['episode_lengths'][-20:])
|
| 380 |
+
print(f"📊 Last 20 episodes - Avg Reward: {final_avg_reward:.2f}, Avg Length: {final_avg_length:.1f}")
|
| 381 |
+
|
| 382 |
+
if self.training_stats['policy_loss']:
|
| 383 |
+
print(f"📉 Final Policy Loss: {self.training_stats['policy_loss'][-1]:.4f}")
|
| 384 |
+
print(f"📉 Final Value Loss: {self.training_stats['value_loss'][-1]:.4f}")
|
| 385 |
+
print(f"🎲 Final Entropy: {self.training_stats['entropy'][-1]:.4f}")
|
| 386 |
+
print(f"📏 Final KL Divergence: {self.training_stats['kl_divergence'][-1]:.6f}")
|
| 387 |
+
print(f"⚖️ Final KL Coefficient: {self.training_stats['kl_coef'][-1]:.6f}")
|
| 388 |
+
print(f"📈 Final Explained Variance: {self.training_stats['explained_variance'][-1]:.3f}")
|
| 389 |
+
|
| 390 |
+
print("=" * 150)
|
| 391 |
+
|
| 392 |
+
def train(self, num_episodes=1000, max_steps_per_episode=500, log_interval=20):
|
| 393 |
+
"""Train PPO-Penalty agent"""
|
| 394 |
+
|
| 395 |
+
print("\n" + "🚀" * 75)
|
| 396 |
+
print("PPO-Penalty (Adaptive KL) Training Started - Gymnasium API")
|
| 397 |
+
print("🚀" * 75)
|
| 398 |
+
|
| 399 |
+
print(f"\n📋 Hyperparameters:")
|
| 400 |
+
print(f" Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")
|
| 401 |
+
print(f" Gamma: {self.gamma:.2f}, GAE Lambda: {self.gae_lambda:.2f}")
|
| 402 |
+
print(f" KL Target: {self.kl_target:.4f}, KL Coef Init: {self.kl_coef:.3f}")
|
| 403 |
+
print(f" KL Adapt Rate: {self.kl_coef_adapt:.2f}")
|
| 404 |
+
print(f" Update Epochs: {self.update_epochs}")
|
| 405 |
+
print(f" Mini-batch: {self.mini_batch_size}, Horizon: {self.horizon}")
|
| 406 |
+
print(f" Entropy Coef: {self.entropy_coef:.3f}, Value Coef: {self.value_coef:.1f}")
|
| 407 |
+
print(f" Device: {self.device}")
|
| 408 |
+
|
| 409 |
+
self.print_header()
|
| 410 |
+
|
| 411 |
+
total_steps = 0
|
| 412 |
+
episode = 0
|
| 413 |
+
start_time = time.time()
|
| 414 |
+
|
| 415 |
+
while episode < num_episodes:
|
| 416 |
+
# Gymnasium API: reset() returns state, info
|
| 417 |
+
state, _ = self.env.reset()
|
| 418 |
+
episode_reward = 0
|
| 419 |
+
episode_step = 0
|
| 420 |
+
|
| 421 |
+
while episode_step < max_steps_per_episode:
|
| 422 |
+
action, log_prob, value = self.select_action(state)
|
| 423 |
+
|
| 424 |
+
# Gymnasium API: step() returns next_state, reward, terminated, truncated, info
|
| 425 |
+
next_state, reward, terminated, truncated, _ = self.env.step(action)
|
| 426 |
+
done = terminated or truncated
|
| 427 |
+
|
| 428 |
+
self.store_transition(state, action, reward, next_state, done, log_prob, value)
|
| 429 |
+
|
| 430 |
+
episode_reward += reward
|
| 431 |
+
episode_step += 1
|
| 432 |
+
total_steps += 1
|
| 433 |
+
state = next_state
|
| 434 |
+
|
| 435 |
+
# Update when buffer is full or episode ends
|
| 436 |
+
if len(self.buffer['rewards']) >= self.horizon or done:
|
| 437 |
+
stats = self.update()
|
| 438 |
+
if stats:
|
| 439 |
+
self.training_stats['policy_loss'].append(stats['policy_loss'])
|
| 440 |
+
self.training_stats['value_loss'].append(stats['value_loss'])
|
| 441 |
+
self.training_stats['entropy'].append(stats['entropy'])
|
| 442 |
+
self.training_stats['kl_divergence'].append(stats['kl_divergence'])
|
| 443 |
+
self.training_stats['kl_coef'].append(stats['kl_coef'])
|
| 444 |
+
self.training_stats['explained_variance'].append(stats['explained_variance'])
|
| 445 |
+
|
| 446 |
+
if done:
|
| 447 |
+
break
|
| 448 |
+
|
| 449 |
+
# Record episode statistics
|
| 450 |
+
self.training_stats['episode_rewards'].append(episode_reward)
|
| 451 |
+
self.training_stats['episode_lengths'].append(episode_step)
|
| 452 |
+
|
| 453 |
+
episode += 1
|
| 454 |
+
|
| 455 |
+
# Log periodically
|
| 456 |
+
if episode % log_interval == 0:
|
| 457 |
+
current_time = time.time()
|
| 458 |
+
elapsed = current_time - start_time
|
| 459 |
+
|
| 460 |
+
recent_stats = None
|
| 461 |
+
if self.training_stats['policy_loss']:
|
| 462 |
+
recent_stats = {
|
| 463 |
+
'policy_loss': self.training_stats['policy_loss'][-1],
|
| 464 |
+
'value_loss': self.training_stats['value_loss'][-1],
|
| 465 |
+
'entropy': self.training_stats['entropy'][-1],
|
| 466 |
+
'kl_divergence': self.training_stats['kl_divergence'][-1],
|
| 467 |
+
'kl_coef': self.training_stats['kl_coef'][-1],
|
| 468 |
+
'explained_variance': self.training_stats['explained_variance'][-1]
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
self.print_row(episode, num_episodes, recent_stats, elapsed)
|
| 472 |
+
|
| 473 |
+
total_time = time.time() - start_time
|
| 474 |
+
self.print_summary(total_time, num_episodes)
|
| 475 |
+
|
| 476 |
+
return self.training_stats['episode_rewards'], self.training_stats['episode_lengths']
|
| 477 |
+
|
| 478 |
+
def save(self, path):
|
| 479 |
+
"""Save model"""
|
| 480 |
+
torch.save({
|
| 481 |
+
'policy_state_dict': self.policy.state_dict(),
|
| 482 |
+
'old_policy_state_dict': self.old_policy.state_dict(),
|
| 483 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 484 |
+
'kl_coef': self.kl_coef,
|
| 485 |
+
'training_stats': self.training_stats
|
| 486 |
+
}, path)
|
| 487 |
+
print(f"\n💾 Model saved to {path}")
|
| 488 |
+
|
| 489 |
+
def load(self, path):
|
| 490 |
+
"""Load model"""
|
| 491 |
+
checkpoint = torch.load(path)
|
| 492 |
+
self.policy.load_state_dict(checkpoint['policy_state_dict'])
|
| 493 |
+
self.old_policy.load_state_dict(checkpoint['old_policy_state_dict'])
|
| 494 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 495 |
+
self.kl_coef = checkpoint['kl_coef']
|
| 496 |
+
self.training_stats = checkpoint['training_stats']
|
| 497 |
+
print(f"\n📂 Model loaded from {path}")
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# ============== 3. Evaluation Function ==============
|
| 501 |
+
def evaluate_agent(agent, env, num_episodes=10, render=False):
|
| 502 |
+
"""Evaluate trained agent"""
|
| 503 |
+
print("\n" + "🎯" * 35)
|
| 504 |
+
print("Evaluation Started")
|
| 505 |
+
print("🎯" * 35)
|
| 506 |
+
|
| 507 |
+
print(f"\n{'Episode':^12} | {'Reward':^12} | {'Length':^12} | {'Avg Reward':^14}")
|
| 508 |
+
print("-" * 60)
|
| 509 |
+
|
| 510 |
+
episode_rewards = []
|
| 511 |
+
episode_lengths = []
|
| 512 |
+
|
| 513 |
+
for episode in range(num_episodes):
|
| 514 |
+
state, _ = env.reset()
|
| 515 |
+
episode_reward = 0
|
| 516 |
+
episode_step = 0
|
| 517 |
+
|
| 518 |
+
while True:
|
| 519 |
+
if render:
|
| 520 |
+
env.render()
|
| 521 |
+
time.sleep(0.02)
|
| 522 |
+
|
| 523 |
+
action, _, _ = agent.select_action(state, eval_mode=True)
|
| 524 |
+
next_state, reward, terminated, truncated, _ = env.step(action)
|
| 525 |
+
done = terminated or truncated
|
| 526 |
+
|
| 527 |
+
episode_reward += reward
|
| 528 |
+
episode_step += 1
|
| 529 |
+
state = next_state
|
| 530 |
+
|
| 531 |
+
if done:
|
| 532 |
+
break
|
| 533 |
+
|
| 534 |
+
episode_rewards.append(episode_reward)
|
| 535 |
+
episode_lengths.append(episode_step)
|
| 536 |
+
|
| 537 |
+
avg_so_far = np.mean(episode_rewards)
|
| 538 |
+
print(f"{episode + 1:^12} | {episode_reward:^12.1f} | {episode_step:^12} | {avg_so_far:^14.2f}")
|
| 539 |
+
|
| 540 |
+
print("-" * 60)
|
| 541 |
+
print(f"\n📊 Evaluation Results ({num_episodes} episodes):")
|
| 542 |
+
print(f" Avg Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
|
| 543 |
+
print(f" Avg Length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")
|
| 544 |
+
print(f" Max Reward: {np.max(episode_rewards):.2f}")
|
| 545 |
+
print(f" Min Reward: {np.min(episode_rewards):.2f}")
|
| 546 |
+
print(f" Success Rate (>=475): {np.mean(np.array(episode_rewards) >= 475) * 100:.1f}%")
|
| 547 |
+
print("=" * 60)
|
| 548 |
+
|
| 549 |
+
return episode_rewards, episode_lengths
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# ============== 4. Visualization Function ==============
|
| 553 |
+
def plot_training_results(agent, save_path='ppo_penalty_training_results.png'):
|
| 554 |
+
"""Plot training results"""
|
| 555 |
+
|
| 556 |
+
stats = agent.training_stats
|
| 557 |
+
|
| 558 |
+
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
| 559 |
+
fig.suptitle('PPO-Penalty (Adaptive KL) Training Results - CartPole-v1', fontsize=16, y=1.02)
|
| 560 |
+
|
| 561 |
+
window = 20
|
| 562 |
+
|
| 563 |
+
# 1. Episode Rewards
|
| 564 |
+
ax1 = axes[0, 0]
|
| 565 |
+
rewards = stats['episode_rewards']
|
| 566 |
+
ax1.plot(rewards, alpha=0.3, color='blue', label='Raw Reward')
|
| 567 |
+
|
| 568 |
+
if len(rewards) >= window:
|
| 569 |
+
smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
|
| 570 |
+
ax1.plot(range(window - 1, len(smoothed) + window - 1), smoothed,
|
| 571 |
+
'r-', linewidth=2, label=f'{window}-Episode MA')
|
| 572 |
+
|
| 573 |
+
ax1.axhline(y=500, color='green', linestyle='--', alpha=0.7, label='Target (500)')
|
| 574 |
+
ax1.set_xlabel('Episode')
|
| 575 |
+
ax1.set_ylabel('Total Reward')
|
| 576 |
+
ax1.set_title('Training Rewards')
|
| 577 |
+
ax1.legend()
|
| 578 |
+
ax1.grid(True, alpha=0.3)
|
| 579 |
+
|
| 580 |
+
# 2. Episode Lengths
|
| 581 |
+
ax2 = axes[0, 1]
|
| 582 |
+
lengths = stats['episode_lengths']
|
| 583 |
+
ax2.plot(lengths, alpha=0.3, color='orange', label='Episode Length')
|
| 584 |
+
|
| 585 |
+
if len(lengths) >= window:
|
| 586 |
+
smoothed_length = np.convolve(lengths, np.ones(window) / window, mode='valid')
|
| 587 |
+
ax2.plot(range(window - 1, len(smoothed_length) + window - 1), smoothed_length,
|
| 588 |
+
'r-', linewidth=2)
|
| 589 |
+
|
| 590 |
+
ax2.set_xlabel('Episode')
|
| 591 |
+
ax2.set_ylabel('Episode Length')
|
| 592 |
+
ax2.set_title('Episode Lengths')
|
| 593 |
+
ax2.grid(True, alpha=0.3)
|
| 594 |
+
|
| 595 |
+
# 3. Policy Loss & KL Divergence
|
| 596 |
+
ax3 = axes[0, 2]
|
| 597 |
+
if stats['policy_loss']:
|
| 598 |
+
ax3_twin = ax3.twinx()
|
| 599 |
+
ax3.plot(stats['policy_loss'], color='purple', linewidth=1.5, label='Policy Loss')
|
| 600 |
+
ax3.set_xlabel('Update Step')
|
| 601 |
+
ax3.set_ylabel('Policy Loss', color='purple')
|
| 602 |
+
ax3.tick_params(axis='y', labelcolor='purple')
|
| 603 |
+
|
| 604 |
+
if stats['kl_divergence']:
|
| 605 |
+
ax3_twin.plot(stats['kl_divergence'], color='orange', linewidth=1.5, label='KL Div')
|
| 606 |
+
ax3_twin.set_ylabel('KL Divergence', color='orange')
|
| 607 |
+
ax3_twin.tick_params(axis='y', labelcolor='orange')
|
| 608 |
+
ax3.grid(True, alpha=0.3)
|
| 609 |
+
|
| 610 |
+
# 4. Value Loss
|
| 611 |
+
ax4 = axes[1, 0]
|
| 612 |
+
if stats['value_loss']:
|
| 613 |
+
ax4.plot(stats['value_loss'], color='brown', linewidth=1.5)
|
| 614 |
+
ax4.set_xlabel('Update Step')
|
| 615 |
+
ax4.set_ylabel('Value Loss')
|
| 616 |
+
ax4.set_title('Value Loss')
|
| 617 |
+
ax4.grid(True, alpha=0.3)
|
| 618 |
+
|
| 619 |
+
# 5. Policy Entropy
|
| 620 |
+
ax5 = axes[1, 1]
|
| 621 |
+
if stats['entropy']:
|
| 622 |
+
ax5.plot(stats['entropy'], color='green', linewidth=1.5)
|
| 623 |
+
ax5.set_xlabel('Update Step')
|
| 624 |
+
ax5.set_ylabel('Policy Entropy')
|
| 625 |
+
ax5.set_title('Policy Entropy (Exploration)')
|
| 626 |
+
ax5.grid(True, alpha=0.3)
|
| 627 |
+
|
| 628 |
+
# 6. KL Coefficient & Explained Variance
|
| 629 |
+
ax6 = axes[1, 2]
|
| 630 |
+
if stats['kl_coef']:
|
| 631 |
+
ax6_twin = ax6.twinx()
|
| 632 |
+
ax6.plot(stats['kl_coef'], color='red', linewidth=1.5, label='KL Coef')
|
| 633 |
+
ax6.set_xlabel('Update Step')
|
| 634 |
+
ax6.set_ylabel('KL Coefficient', color='red')
|
| 635 |
+
ax6.tick_params(axis='y', labelcolor='red')
|
| 636 |
+
|
| 637 |
+
if stats['explained_variance']:
|
| 638 |
+
ax6_twin.plot(stats['explained_variance'], color='blue', linewidth=1.5, label='Expl Var')
|
| 639 |
+
ax6_twin.set_ylabel('Explained Variance', color='blue')
|
| 640 |
+
ax6_twin.tick_params(axis='y', labelcolor='blue')
|
| 641 |
+
ax6.grid(True, alpha=0.3)
|
| 642 |
+
|
| 643 |
+
plt.tight_layout()
|
| 644 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 645 |
+
plt.show()
|
| 646 |
+
print(f"\n📸 Training results saved to {save_path}")
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
# ============== 5. Main Function ==============
|
| 650 |
+
def main():
|
| 651 |
+
"""Main function"""
|
| 652 |
+
|
| 653 |
+
# Create environment (Gymnasium)
|
| 654 |
+
env = gym.make('CartPole-v1')
|
| 655 |
+
|
| 656 |
+
# PPO-Penalty hyperparameters
|
| 657 |
+
config = {
|
| 658 |
+
'learning_rate': 3e-4,
|
| 659 |
+
'gamma': 0.99,
|
| 660 |
+
'gae_lambda': 0.95,
|
| 661 |
+
'kl_target': 0.01, # Target KL divergence per update
|
| 662 |
+
'kl_coef_init': 1.0, # Initial KL penalty coefficient
|
| 663 |
+
'kl_coef_adapt': 1.5, # Adaptation rate
|
| 664 |
+
'entropy_coef': 0.01,
|
| 665 |
+
'value_coef': 0.5,
|
| 666 |
+
'max_grad_norm': 0.5,
|
| 667 |
+
'update_epochs': 10, # More epochs for PPO-Penalty
|
| 668 |
+
'mini_batch_size': 64,
|
| 669 |
+
'horizon': 2048,
|
| 670 |
+
'hidden_dim': 64
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
print("\n" + "=" * 100)
|
| 674 |
+
print("PPO-Penalty (Adaptive KL) for CartPole-v1 (Gymnasium)")
|
| 675 |
+
print("=" * 100)
|
| 676 |
+
print("\n📋 Hyperparameters:")
|
| 677 |
+
for key, value in config.items():
|
| 678 |
+
print(f" {key:20}: {value}")
|
| 679 |
+
|
| 680 |
+
# Create PPO-Penalty agent
|
| 681 |
+
agent = PPOPenaltyAgent(env, **config)
|
| 682 |
+
|
| 683 |
+
try:
|
| 684 |
+
# Train
|
| 685 |
+
rewards, lengths = agent.train(num_episodes=500, log_interval=20)
|
| 686 |
+
|
| 687 |
+
# Save model
|
| 688 |
+
agent.save('ppo_penalty_cartpole.pth')
|
| 689 |
+
|
| 690 |
+
# Plot results
|
| 691 |
+
plot_training_results(agent)
|
| 692 |
+
|
| 693 |
+
# Evaluate
|
| 694 |
+
print("\n")
|
| 695 |
+
eval_env = gym.make('CartPole-v1')
|
| 696 |
+
evaluate_agent(agent, eval_env, num_episodes=20, render=False)
|
| 697 |
+
|
| 698 |
+
# Demo (with rendering)
|
| 699 |
+
print("\n")
|
| 700 |
+
demo_env = gym.make('CartPole-v1', render_mode='human')
|
| 701 |
+
evaluate_agent(agent, demo_env, num_episodes=3, render=True)
|
| 702 |
+
|
| 703 |
+
except KeyboardInterrupt:
|
| 704 |
+
print("\n\n⚠️ Training interrupted, saving model...")
|
| 705 |
+
agent.save('ppo_penalty_cartpole_interrupted.pth')
|
| 706 |
+
|
| 707 |
+
finally:
|
| 708 |
+
env.close()
|
| 709 |
+
if 'eval_env' in locals():
|
| 710 |
+
eval_env.close()
|
| 711 |
+
if 'demo_env' in locals():
|
| 712 |
+
demo_env.close()
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# ============== 6. Compare PPO-Clip vs PPO-Penalty ==============
|
| 716 |
+
def compare_ppo_variants():
|
| 717 |
+
"""Compare PPO-Clip and PPO-Penalty"""
|
| 718 |
+
|
| 719 |
+
print("\n" + "🔬" * 50)
|
| 720 |
+
print("PPO-Clip vs PPO-Penalty Comparison")
|
| 721 |
+
print("🔬" * 50)
|
| 722 |
+
|
| 723 |
+
# This would require implementing both agents and running experiments
|
| 724 |
+
# For brevity, here's the conceptual comparison:
|
| 725 |
+
|
| 726 |
+
comparison = """
|
| 727 |
+
📊 **PPO-Clip vs PPO-Penalty Comparison**
|
| 728 |
+
|
| 729 |
+
============================================================
|
| 730 |
+
Feature | PPO-Clip | PPO-Penalty
|
| 731 |
+
============================================================
|
| 732 |
+
Constraint Type | Hard clipping | Soft KL penalty
|
| 733 |
+
------------------------------------------------------------
|
| 734 |
+
Update Limit | r ∈ [1-ε, 1+ε] | KL(π||π_old) < target
|
| 735 |
+
------------------------------------------------------------
|
| 736 |
+
Adaptation | Fixed ε | Adaptive KL coef
|
| 737 |
+
------------------------------------------------------------
|
| 738 |
+
Implementation | Simple | More complex
|
| 739 |
+
------------------------------------------------------------
|
| 740 |
+
Compute Cost | Low | Higher (KL calc)
|
| 741 |
+
------------------------------------------------------------
|
| 742 |
+
Stability | Very stable | Very stable
|
| 743 |
+
------------------------------------------------------------
|
| 744 |
+
Sample Efficiency | Good | Good
|
| 745 |
+
------------------------------------------------------------
|
| 746 |
+
Hyperparameter | ε=0.2 (robust) | kl_target=0.01
|
| 747 |
+
------------------------------------------------------------
|
| 748 |
+
TRPO Relation | Approximation | Direct descendant
|
| 749 |
+
------------------------------------------------------------
|
| 750 |
+
|
| 751 |
+
**When to use PPO-Penalty:**
|
| 752 |
+
• When you need precise KL control
|
| 753 |
+
• When you're comfortable tuning kl_target
|
| 754 |
+
• When you want to stay closer to TRPO theory
|
| 755 |
+
|
| 756 |
+
**When to use PPO-Clip:**
|
| 757 |
+
• Default choice for most problems
|
| 758 |
+
• Simpler, fewer hyperparameters
|
| 759 |
+
• More widely adopted in practice
|
| 760 |
+
"""
|
| 761 |
+
|
| 762 |
+
print(comparison)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
if __name__ == "__main__":
|
| 766 |
+
main()
|
| 767 |
+
# compare_ppo_variants() # Uncomment to see comparison
|
examples/tutorials/rl/cart_pole/step_2_reinforce.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
策略梯度法
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
如果在相同的动作序列下,环境会输出相同的状态和奖励,此方法会彻底失效。
|
| 8 |
+
|
| 9 |
+
推车立杆任体力比较简单,每次只有2个动作,它很容易就将每一步的最优选择学会了,但对于复杂任务,REINFORCE 可能要困难得多得多。
|
| 10 |
+
|
| 11 |
+
其本质是搜集多局游戏数据,进行奖励最大化,再依赖环境的随机性进一步迭代。其类似于动态规划算法。 并不保证能找到全局最优解。
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
import gymnasium as gym
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
from torch.distributions import Categorical
|
| 20 |
+
import numpy as np
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
from collections import deque
|
| 23 |
+
import warnings
|
| 24 |
+
|
| 25 |
+
warnings.filterwarnings('ignore')
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ==================== 策略网络 ====================
|
| 29 |
+
class PolicyNetwork(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
策略网络:状态 -> 动作概率分布
|
| 32 |
+
输出是每个动作的概率(经过softmax)
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, state_dim, hidden_dim, action_dim):
|
| 36 |
+
super(PolicyNetwork, self).__init__()
|
| 37 |
+
|
| 38 |
+
self.network = nn.Sequential(
|
| 39 |
+
nn.Linear(state_dim, hidden_dim),
|
| 40 |
+
nn.ReLU(),
|
| 41 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 42 |
+
nn.ReLU(),
|
| 43 |
+
nn.Linear(hidden_dim, action_dim),
|
| 44 |
+
nn.Softmax(dim=-1) # 输出概率分布
|
| 45 |
+
)
|
| 46 |
+
self.apply(self._init_weights)
|
| 47 |
+
|
| 48 |
+
def _init_weights(self, module):
|
| 49 |
+
if isinstance(module, nn.Linear):
|
| 50 |
+
nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
|
| 51 |
+
nn.init.constant_(module.bias, 0)
|
| 52 |
+
|
| 53 |
+
def forward(self, state):
|
| 54 |
+
"""
|
| 55 |
+
输入: state [batch_size, state_dim]
|
| 56 |
+
输出: action_probs [batch_size, action_dim]
|
| 57 |
+
"""
|
| 58 |
+
return self.network(state)
|
| 59 |
+
|
| 60 |
+
def get_action(self, state):
|
| 61 |
+
"""
|
| 62 |
+
根据概率分布采样动作
|
| 63 |
+
返回: action, log_prob
|
| 64 |
+
"""
|
| 65 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 66 |
+
# state shape: [1, 4]
|
| 67 |
+
probs = self.forward(state)
|
| 68 |
+
# probs shape: [1, 2]
|
| 69 |
+
dist = Categorical(probs)
|
| 70 |
+
action = dist.sample()
|
| 71 |
+
# action shape: [1], 一个数值。
|
| 72 |
+
log_prob = dist.log_prob(action)
|
| 73 |
+
# log_prob shape: [1], 一个数值。
|
| 74 |
+
return action.item(), log_prob
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ==================== REINFORCE智能体 ====================
|
| 78 |
+
class ReinforceAgent:
|
| 79 |
+
"""
|
| 80 |
+
REINFORCE: 蒙特卡洛策略梯度
|
| 81 |
+
核心思想:用完整轨迹的累积奖励来更新策略
|
| 82 |
+
好的轨迹 -> 增加这些动作的概率
|
| 83 |
+
坏的轨迹 -> 减少这些动作的概率
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(self,
|
| 87 |
+
state_dim,
|
| 88 |
+
hidden_dim=128,
|
| 89 |
+
action_dim=2,
|
| 90 |
+
lr=1e-3,
|
| 91 |
+
gamma=0.99):
|
| 92 |
+
|
| 93 |
+
self.gamma = gamma
|
| 94 |
+
self.policy = PolicyNetwork(state_dim, hidden_dim, action_dim)
|
| 95 |
+
self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
|
| 96 |
+
|
| 97 |
+
# 轨迹存储
|
| 98 |
+
self.log_probs = [] # 每个动作的对数概率
|
| 99 |
+
self.rewards = [] # 每个时间步的奖励
|
| 100 |
+
|
| 101 |
+
def select_action(self, state):
|
| 102 |
+
"""选择动作并记录对数概率"""
|
| 103 |
+
action, log_prob = self.policy.get_action(state)
|
| 104 |
+
self.log_probs.append(log_prob)
|
| 105 |
+
return action
|
| 106 |
+
|
| 107 |
+
def store_reward(self, reward):
|
| 108 |
+
"""存储奖励"""
|
| 109 |
+
self.rewards.append(reward)
|
| 110 |
+
|
| 111 |
+
def update(self):
|
| 112 |
+
"""
|
| 113 |
+
REINFORCE更新公式:
|
| 114 |
+
∇J = E[ Σ ∇logπ(a_t|s_t) * G_t ]
|
| 115 |
+
其中 G_t = Σ γ^(k-t) * r_k 是从t开始的累积折扣奖励
|
| 116 |
+
"""
|
| 117 |
+
# 计算累积折扣奖励 G_t
|
| 118 |
+
returns = []
|
| 119 |
+
G = 0
|
| 120 |
+
for r in reversed(self.rewards):
|
| 121 |
+
G = r + self.gamma * G
|
| 122 |
+
returns.insert(0, G)
|
| 123 |
+
|
| 124 |
+
returns = torch.tensor(returns)
|
| 125 |
+
|
| 126 |
+
# 标准化 returns(降低方差,不是必须但很有帮助)
|
| 127 |
+
returns = (returns - returns.mean()) / (returns.std() + 1e-9)
|
| 128 |
+
|
| 129 |
+
# 计算策略梯度损失
|
| 130 |
+
policy_loss = []
|
| 131 |
+
for log_prob, G in zip(self.log_probs, returns):
|
| 132 |
+
# 核心公式:-log_prob * G
|
| 133 |
+
# 负号是因为PyTorch做梯度下降,我们要最大化J
|
| 134 |
+
policy_loss.append(-log_prob * G)
|
| 135 |
+
|
| 136 |
+
policy_loss = torch.stack(policy_loss).sum()
|
| 137 |
+
|
| 138 |
+
# 更新策略
|
| 139 |
+
self.optimizer.zero_grad()
|
| 140 |
+
policy_loss.backward()
|
| 141 |
+
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
|
| 142 |
+
self.optimizer.step()
|
| 143 |
+
|
| 144 |
+
# 清空轨迹
|
| 145 |
+
self.log_probs = []
|
| 146 |
+
self.rewards = []
|
| 147 |
+
|
| 148 |
+
return policy_loss.item()
|
| 149 |
+
|
| 150 |
+
def save(self, path):
|
| 151 |
+
torch.save(self.policy.state_dict(), path)
|
| 152 |
+
|
| 153 |
+
def load(self, path):
|
| 154 |
+
self.policy.load_state_dict(torch.load(path))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# ==================== 训练函数 ====================
|
| 158 |
+
def train_reinforce(env_name='CartPole-v1',
|
| 159 |
+
hidden_dim=128,
|
| 160 |
+
lr=1e-3,
|
| 161 |
+
gamma=0.99,
|
| 162 |
+
max_episodes=1000,
|
| 163 |
+
log_interval=20):
|
| 164 |
+
"""
|
| 165 |
+
训练REINFORCE智能体
|
| 166 |
+
"""
|
| 167 |
+
# 创建环境
|
| 168 |
+
# env = gym.make(env_name, render_mode='human')
|
| 169 |
+
env = gym.make(env_name)
|
| 170 |
+
state_dim = env.observation_space.shape[0]
|
| 171 |
+
action_dim = env.action_space.n
|
| 172 |
+
|
| 173 |
+
# 初始化智能体
|
| 174 |
+
agent = ReinforceAgent(
|
| 175 |
+
state_dim=state_dim,
|
| 176 |
+
hidden_dim=hidden_dim,
|
| 177 |
+
action_dim=action_dim,
|
| 178 |
+
lr=lr,
|
| 179 |
+
gamma=gamma
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# 记录训练过程
|
| 183 |
+
episode_rewards = []
|
| 184 |
+
episode_losses = []
|
| 185 |
+
moving_avg_rewards = deque(maxlen=100)
|
| 186 |
+
|
| 187 |
+
print(f"开始训练 REINFORCE on {env_name}")
|
| 188 |
+
print(f"状态维度: {state_dim}, 动作维度: {action_dim}")
|
| 189 |
+
print(f"学习率: {lr}, 折扣因子: {gamma}")
|
| 190 |
+
print("-" * 50)
|
| 191 |
+
|
| 192 |
+
for episode in range(1, max_episodes + 1):
|
| 193 |
+
state, _ = env.reset()
|
| 194 |
+
episode_reward = 0
|
| 195 |
+
episode_loss = 0
|
| 196 |
+
done = False
|
| 197 |
+
|
| 198 |
+
# 收集一条完整轨迹
|
| 199 |
+
while not done:
|
| 200 |
+
# 选择动作
|
| 201 |
+
action = agent.select_action(state)
|
| 202 |
+
|
| 203 |
+
# 执行动作
|
| 204 |
+
next_state, reward, terminated, truncated, _ = env.step(action)
|
| 205 |
+
done = terminated or truncated
|
| 206 |
+
|
| 207 |
+
# 存储奖励
|
| 208 |
+
agent.store_reward(reward)
|
| 209 |
+
|
| 210 |
+
state = next_state
|
| 211 |
+
episode_reward += reward
|
| 212 |
+
|
| 213 |
+
# 一条轨迹结束后更新策略
|
| 214 |
+
loss = agent.update()
|
| 215 |
+
episode_loss = loss
|
| 216 |
+
|
| 217 |
+
# 记录
|
| 218 |
+
episode_rewards.append(episode_reward)
|
| 219 |
+
episode_losses.append(episode_loss)
|
| 220 |
+
moving_avg_rewards.append(episode_reward)
|
| 221 |
+
|
| 222 |
+
# 打印进度
|
| 223 |
+
if episode % log_interval == 0:
|
| 224 |
+
avg_reward = np.mean(moving_avg_rewards)
|
| 225 |
+
print(f"Episode {episode:5d} | "
|
| 226 |
+
f"Reward: {episode_reward:6.2f} | "
|
| 227 |
+
f"Avg Reward: {avg_reward:6.2f} | "
|
| 228 |
+
f"Loss: {episode_loss:8.4f}")
|
| 229 |
+
|
| 230 |
+
# 早停:如果连续100局平均分>=475
|
| 231 |
+
if len(moving_avg_rewards) == 100 and np.mean(moving_avg_rewards) >= 475:
|
| 232 |
+
print(f"\n🎉 在第 {episode} 回合解决问题!平均奖励: {np.mean(moving_avg_rewards):.2f}")
|
| 233 |
+
break
|
| 234 |
+
|
| 235 |
+
env.close()
|
| 236 |
+
return agent, episode_rewards, episode_losses
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ==================== 可视化函数 ====================
|
| 240 |
+
def plot_training(rewards, losses, save_path=None):
|
| 241 |
+
"""
|
| 242 |
+
绘制训练曲线
|
| 243 |
+
"""
|
| 244 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
| 245 |
+
|
| 246 |
+
# 1. 原始奖励曲线
|
| 247 |
+
axes[0, 0].plot(rewards, alpha=0.6, color='blue', linewidth=0.8)
|
| 248 |
+
axes[0, 0].set_xlabel('Episode')
|
| 249 |
+
axes[0, 0].set_ylabel('Total Reward')
|
| 250 |
+
axes[0, 0].set_title('Training Rewards')
|
| 251 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 252 |
+
|
| 253 |
+
# 2. 移动平均奖励
|
| 254 |
+
window = 20
|
| 255 |
+
moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
|
| 256 |
+
axes[0, 1].plot(moving_avg, color='red', linewidth=2)
|
| 257 |
+
axes[0, 1].fill_between(range(len(moving_avg)),
|
| 258 |
+
moving_avg - np.std(rewards[:len(moving_avg)]),
|
| 259 |
+
moving_avg + np.std(rewards[:len(moving_avg)]),
|
| 260 |
+
alpha=0.2, color='red')
|
| 261 |
+
axes[0, 1].set_xlabel('Episode')
|
| 262 |
+
axes[0, 1].set_ylabel(f'Moving Avg Reward (window={window})')
|
| 263 |
+
axes[0, 1].set_title('Smoothed Training Curve')
|
| 264 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 265 |
+
|
| 266 |
+
# 3. 损失曲线
|
| 267 |
+
axes[1, 0].plot(losses, color='green', alpha=0.6, linewidth=0.8)
|
| 268 |
+
axes[1, 0].set_xlabel('Episode')
|
| 269 |
+
axes[1, 0].set_ylabel('Policy Loss')
|
| 270 |
+
axes[1, 0].set_title('Training Loss')
|
| 271 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 272 |
+
|
| 273 |
+
# 4. 奖励分布直方图
|
| 274 |
+
axes[1, 1].hist(rewards[-100:], bins=20, color='purple', alpha=0.7, edgecolor='black')
|
| 275 |
+
axes[1, 1].set_xlabel('Total Reward')
|
| 276 |
+
axes[1, 1].set_ylabel('Frequency')
|
| 277 |
+
axes[1, 1].set_title('Reward Distribution (Last 100 Episodes)')
|
| 278 |
+
axes[1, 1].axvline(x=np.mean(rewards[-100:]), color='red', linestyle='--',
|
| 279 |
+
label=f'Mean: {np.mean(rewards[-100:]):.1f}')
|
| 280 |
+
axes[1, 1].legend()
|
| 281 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 282 |
+
|
| 283 |
+
plt.tight_layout()
|
| 284 |
+
|
| 285 |
+
if save_path:
|
| 286 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 287 |
+
plt.show()
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ==================== 演示智能体 ====================
|
| 291 |
+
def demo_agent(agent, env_name='CartPole-v1', episodes=5):
|
| 292 |
+
"""
|
| 293 |
+
演示训练好的智能体
|
| 294 |
+
"""
|
| 295 |
+
env = gym.make(env_name, render_mode='human')
|
| 296 |
+
# env = gym.make(env_name)
|
| 297 |
+
|
| 298 |
+
for episode in range(episodes):
|
| 299 |
+
state, _ = env.reset()
|
| 300 |
+
total_reward = 0
|
| 301 |
+
done = False
|
| 302 |
+
|
| 303 |
+
while not done:
|
| 304 |
+
action, _ = agent.policy.get_action(state)
|
| 305 |
+
state, reward, terminated, truncated, _ = env.step(action)
|
| 306 |
+
done = terminated or truncated
|
| 307 |
+
total_reward += reward
|
| 308 |
+
|
| 309 |
+
print(f"Demo Episode {episode + 1}: Reward = {total_reward}")
|
| 310 |
+
|
| 311 |
+
env.close()
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# ==================== 超参数调优 ====================
|
| 315 |
+
def hyperparameter_sweep():
|
| 316 |
+
"""
|
| 317 |
+
简单的超参数搜索
|
| 318 |
+
"""
|
| 319 |
+
learning_rates = [1e-4, 3e-4, 1e-3, 3e-3]
|
| 320 |
+
hidden_sizes = [64, 128, 256]
|
| 321 |
+
|
| 322 |
+
results = {}
|
| 323 |
+
|
| 324 |
+
for lr in learning_rates:
|
| 325 |
+
for hidden in hidden_sizes:
|
| 326 |
+
print(f"\n测试 lr={lr}, hidden={hidden}")
|
| 327 |
+
_, rewards, _ = train_reinforce(
|
| 328 |
+
lr=lr,
|
| 329 |
+
hidden_dim=hidden,
|
| 330 |
+
max_episodes=300,
|
| 331 |
+
log_interval=50
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
avg_reward = np.mean(rewards[-50:])
|
| 335 |
+
results[(lr, hidden)] = avg_reward
|
| 336 |
+
print(f"平均奖励: {avg_reward:.2f}")
|
| 337 |
+
|
| 338 |
+
# 找出最佳参数
|
| 339 |
+
best_params = max(results, key=results.get)
|
| 340 |
+
print(f"\n最佳参数: lr={best_params[0]}, hidden={best_params[1]}")
|
| 341 |
+
print(f"最佳平均奖励: {results[best_params]:.2f}")
|
| 342 |
+
|
| 343 |
+
return results
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ==================== 主程序 ====================
|
| 347 |
+
if __name__ == "__main__":
|
| 348 |
+
# 设置随机种子(可复现)
|
| 349 |
+
torch.manual_seed(42)
|
| 350 |
+
np.random.seed(42)
|
| 351 |
+
|
| 352 |
+
# 训练参数
|
| 353 |
+
CONFIG = {
|
| 354 |
+
'env_name': 'CartPole-v1',
|
| 355 |
+
'hidden_dim': 128,
|
| 356 |
+
'lr': 1e-3,
|
| 357 |
+
'gamma': 0.99,
|
| 358 |
+
'max_episodes': 800,
|
| 359 |
+
'log_interval': 20
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
# 训练智能体
|
| 363 |
+
agent, rewards, losses = train_reinforce(**CONFIG)
|
| 364 |
+
|
| 365 |
+
# 绘制训练曲线
|
| 366 |
+
plot_training(rewards, losses, save_path='reinforce_training.png')
|
| 367 |
+
|
| 368 |
+
# 打印最终结果
|
| 369 |
+
print("\n" + "=" * 50)
|
| 370 |
+
print("训练完成!")
|
| 371 |
+
print(f"最高奖励: {max(rewards):.2f}")
|
| 372 |
+
print(f"平均奖励(最后100局): {np.mean(rewards[-100:]):.2f}")
|
| 373 |
+
print(f"标准差(最后100局): {np.std(rewards[-100:]):.2f}")
|
| 374 |
+
print("=" * 50)
|
| 375 |
+
|
| 376 |
+
# 保存模型
|
| 377 |
+
agent.save('reinforce_cartpole.pth')
|
| 378 |
+
print("模型已保存到 reinforce_cartpole.pth")
|
| 379 |
+
|
| 380 |
+
# 演示
|
| 381 |
+
print("\n开始演示...")
|
| 382 |
+
demo_agent(agent, episodes=3)
|
examples/tutorials/rl/cart_pole/step_2_reinforce_with_baseline.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
带基线的方法(REINFORCE with Baseline),相当于是先预测一个未来预期奖励的估计,
|
| 5 |
+
然后如果实际的奖励大于这个值,则模型得到正向反馈,动作的概率会被加大,
|
| 6 |
+
如果小于这个值,则动作的概率会减小。
|
| 7 |
+
这相比于原始的方法(REINFORCE)永远只是增大动作的概率使训练变得更稳定。
|
| 8 |
+
|
| 9 |
+
如果按照训练的轮次/游戏局数来比较,带基线的方法确实收敛更快。
|
| 10 |
+
但额外的价值网络会增加额外的价值网络。
|
| 11 |
+
|
| 12 |
+
原始REINFORCE 方法,当确定的动作序列导致确定的状态序列时,会失效。
|
| 13 |
+
但带基线REINFORCE 仍然有效。
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
import argparse
|
| 17 |
+
from collections import deque
|
| 18 |
+
|
| 19 |
+
import gymnasium as gym
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.optim as optim
|
| 25 |
+
from torch.distributions import Categorical
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PolicyNetwork(nn.Module):
|
| 29 |
+
"""策略网络 - 输出动作概率分布"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, state_dim, hidden_dim, action_dim):
|
| 32 |
+
super(PolicyNetwork, self).__init__()
|
| 33 |
+
self.fc1 = nn.Linear(state_dim, hidden_dim)
|
| 34 |
+
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
|
| 35 |
+
self.fc3 = nn.Linear(hidden_dim, action_dim)
|
| 36 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 37 |
+
|
| 38 |
+
def forward(self, state):
|
| 39 |
+
x = torch.relu(self.fc1(state))
|
| 40 |
+
x = torch.relu(self.fc2(x))
|
| 41 |
+
x = self.fc3(x)
|
| 42 |
+
return self.softmax(x)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ValueNetwork(nn.Module):
|
| 46 |
+
"""价值网络 - 作为基线,估计状态价值"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, state_dim, hidden_dim):
|
| 49 |
+
super(ValueNetwork, self).__init__()
|
| 50 |
+
self.fc1 = nn.Linear(state_dim, hidden_dim)
|
| 51 |
+
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
|
| 52 |
+
self.fc3 = nn.Linear(hidden_dim, 1)
|
| 53 |
+
|
| 54 |
+
def forward(self, state):
|
| 55 |
+
x = torch.relu(self.fc1(state))
|
| 56 |
+
x = torch.relu(self.fc2(x))
|
| 57 |
+
return self.fc3(x)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class REINFORCEwithBaseline:
|
| 61 |
+
"""带基线的策略梯度算法"""
|
| 62 |
+
|
| 63 |
+
def __init__(self,
|
| 64 |
+
env,
|
| 65 |
+
policy_lr=1e-3,
|
| 66 |
+
value_lr=1e-3,
|
| 67 |
+
gamma=0.99,
|
| 68 |
+
hidden_dim=128,
|
| 69 |
+
render=False):
|
| 70 |
+
|
| 71 |
+
self.env = env
|
| 72 |
+
self.gamma = gamma
|
| 73 |
+
self.render = render
|
| 74 |
+
|
| 75 |
+
# 获取状态和动作维度
|
| 76 |
+
self.state_dim = env.observation_space.shape[0]
|
| 77 |
+
self.action_dim = env.action_space.n
|
| 78 |
+
|
| 79 |
+
# 初始化策略网络和价值网络
|
| 80 |
+
self.policy_net = PolicyNetwork(self.state_dim, hidden_dim, self.action_dim)
|
| 81 |
+
self.value_net = ValueNetwork(self.state_dim, hidden_dim)
|
| 82 |
+
|
| 83 |
+
# 优化器
|
| 84 |
+
self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
|
| 85 |
+
self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=value_lr)
|
| 86 |
+
|
| 87 |
+
# 存储轨迹
|
| 88 |
+
self.reset_memory()
|
| 89 |
+
|
| 90 |
+
# 记录训练信息
|
| 91 |
+
self.training_stats = {'episode_rewards': [], 'baseline_loss': []}
|
| 92 |
+
|
| 93 |
+
def reset_memory(self):
|
| 94 |
+
"""重置存储的记忆"""
|
| 95 |
+
self.states = []
|
| 96 |
+
self.actions = []
|
| 97 |
+
self.rewards = []
|
| 98 |
+
self.log_probs = []
|
| 99 |
+
|
| 100 |
+
def select_action(self, state):
|
| 101 |
+
"""根据当前策略选择动作"""
|
| 102 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 103 |
+
probs = self.policy_net(state)
|
| 104 |
+
m = Categorical(probs)
|
| 105 |
+
action = m.sample()
|
| 106 |
+
log_prob = m.log_prob(action)
|
| 107 |
+
|
| 108 |
+
# 存储经验
|
| 109 |
+
self.states.append(state)
|
| 110 |
+
self.actions.append(action)
|
| 111 |
+
self.log_probs.append(log_prob)
|
| 112 |
+
|
| 113 |
+
return action.item()
|
| 114 |
+
|
| 115 |
+
def compute_returns(self):
|
| 116 |
+
"""计算折扣回报"""
|
| 117 |
+
returns = []
|
| 118 |
+
R = 0
|
| 119 |
+
for r in reversed(self.rewards):
|
| 120 |
+
R = r + self.gamma * R
|
| 121 |
+
returns.insert(0, R)
|
| 122 |
+
returns = torch.FloatTensor(returns)
|
| 123 |
+
# 标准化回报以稳定训练
|
| 124 |
+
returns = (returns - returns.mean()) / (returns.std() + 1e-9)
|
| 125 |
+
return returns
|
| 126 |
+
|
| 127 |
+
def update(self):
|
| 128 |
+
"""更新策略网络和价值网络"""
|
| 129 |
+
if len(self.rewards) == 0:
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
# 计算回报和价值估计
|
| 133 |
+
returns = self.compute_returns()
|
| 134 |
+
states = torch.cat(self.states)
|
| 135 |
+
|
| 136 |
+
# 1. 更新价值网络(基线)
|
| 137 |
+
value_pred = self.value_net(states).squeeze()
|
| 138 |
+
value_loss = nn.MSELoss()(value_pred, returns)
|
| 139 |
+
|
| 140 |
+
self.value_optimizer.zero_grad()
|
| 141 |
+
value_loss.backward()
|
| 142 |
+
torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)
|
| 143 |
+
self.value_optimizer.step()
|
| 144 |
+
|
| 145 |
+
# 2. 更新策略网络
|
| 146 |
+
# 重新计算价值估计用于优势函数
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
baselines = self.value_net(states).squeeze()
|
| 149 |
+
|
| 150 |
+
# 计算策略梯度
|
| 151 |
+
policy_loss = []
|
| 152 |
+
for log_prob, ret, baseline in zip(self.log_probs, returns, baselines):
|
| 153 |
+
advantage = ret - baseline
|
| 154 |
+
policy_loss.append(-log_prob * advantage)
|
| 155 |
+
|
| 156 |
+
policy_loss = torch.cat(policy_loss).sum()
|
| 157 |
+
|
| 158 |
+
self.policy_optimizer.zero_grad()
|
| 159 |
+
policy_loss.backward()
|
| 160 |
+
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
|
| 161 |
+
self.policy_optimizer.step()
|
| 162 |
+
|
| 163 |
+
# 记录价值损失
|
| 164 |
+
self.training_stats['baseline_loss'].append(value_loss.item())
|
| 165 |
+
|
| 166 |
+
# 清空记忆
|
| 167 |
+
self.reset_memory()
|
| 168 |
+
|
| 169 |
+
def train(self, num_episodes, max_steps_per_episode=500):
|
| 170 |
+
"""训练智能体"""
|
| 171 |
+
episode_rewards = []
|
| 172 |
+
best_avg_reward = -np.inf
|
| 173 |
+
reward_window = deque(maxlen=100)
|
| 174 |
+
|
| 175 |
+
for episode in range(num_episodes):
|
| 176 |
+
state, _ = self.env.reset()
|
| 177 |
+
episode_reward = 0
|
| 178 |
+
|
| 179 |
+
for step in range(max_steps_per_episode):
|
| 180 |
+
if self.render:
|
| 181 |
+
self.env.render()
|
| 182 |
+
|
| 183 |
+
action = self.select_action(state)
|
| 184 |
+
next_state, reward, terminated, truncated, _ = self.env.step(action)
|
| 185 |
+
done = terminated or truncated
|
| 186 |
+
|
| 187 |
+
self.rewards.append(reward)
|
| 188 |
+
episode_reward += reward
|
| 189 |
+
|
| 190 |
+
if done:
|
| 191 |
+
break
|
| 192 |
+
|
| 193 |
+
state = next_state
|
| 194 |
+
|
| 195 |
+
# 更新网络
|
| 196 |
+
self.update()
|
| 197 |
+
|
| 198 |
+
# 记录并输出训练信息
|
| 199 |
+
episode_rewards.append(episode_reward)
|
| 200 |
+
reward_window.append(episode_reward)
|
| 201 |
+
avg_reward = np.mean(reward_window)
|
| 202 |
+
|
| 203 |
+
self.training_stats['episode_rewards'].append(episode_reward)
|
| 204 |
+
|
| 205 |
+
if (episode + 1) % 10 == 0:
|
| 206 |
+
avg_baseline_loss = np.mean(self.training_stats['baseline_loss'][-10:])
|
| 207 |
+
print(f'Episode {episode + 1}/{num_episodes}, '
|
| 208 |
+
f'Reward: {episode_reward:.1f}, '
|
| 209 |
+
f'Avg Reward: {avg_reward:.1f}, '
|
| 210 |
+
f'Baseline Loss: {avg_baseline_loss:.4f}')
|
| 211 |
+
|
| 212 |
+
# 保存最佳模型
|
| 213 |
+
if avg_reward > best_avg_reward and episode > 100:
|
| 214 |
+
best_avg_reward = avg_reward
|
| 215 |
+
torch.save({
|
| 216 |
+
'policy_state_dict': self.policy_net.state_dict(),
|
| 217 |
+
'value_state_dict': self.value_net.state_dict(),
|
| 218 |
+
'avg_reward': avg_reward,
|
| 219 |
+
'episode': episode
|
| 220 |
+
}, 'best_model.pth')
|
| 221 |
+
|
| 222 |
+
return episode_rewards
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def plot_training_results(rewards, baseline_loss, window=100):
|
| 226 |
+
"""绘制训练结果"""
|
| 227 |
+
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
|
| 228 |
+
|
| 229 |
+
# 绘制奖励曲线
|
| 230 |
+
axes[0].plot(rewards, alpha=0.3, color='blue', label='Episode Reward')
|
| 231 |
+
|
| 232 |
+
# 绘制平滑曲线
|
| 233 |
+
if len(rewards) >= window:
|
| 234 |
+
smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
|
| 235 |
+
axes[0].plot(range(window - 1, len(rewards)), smoothed,
|
| 236 |
+
color='red', linewidth=2, label=f'Moving Avg (window={window})')
|
| 237 |
+
|
| 238 |
+
axes[0].set_xlabel('Episode')
|
| 239 |
+
axes[0].set_ylabel('Total Reward')
|
| 240 |
+
axes[0].set_title('REINFORCE with Baseline - Training Rewards')
|
| 241 |
+
axes[0].legend()
|
| 242 |
+
axes[0].grid(True, alpha=0.3)
|
| 243 |
+
|
| 244 |
+
# 绘制基线损失
|
| 245 |
+
axes[1].plot(baseline_loss, color='green', alpha=0.6, label='Baseline Loss')
|
| 246 |
+
axes[1].set_xlabel('Update Step')
|
| 247 |
+
axes[1].set_ylabel('MSE Loss')
|
| 248 |
+
axes[1].set_title('Value Network (Baseline) Loss')
|
| 249 |
+
axes[1].legend()
|
| 250 |
+
axes[1].grid(True, alpha=0.3)
|
| 251 |
+
|
| 252 |
+
plt.tight_layout()
|
| 253 |
+
plt.savefig('training_results.png', dpi=100)
|
| 254 |
+
plt.show()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def test_agent(env, policy_net, num_episodes=10, render=True):
|
| 258 |
+
"""测试训练好的智能体"""
|
| 259 |
+
episode_rewards = []
|
| 260 |
+
|
| 261 |
+
for episode in range(num_episodes):
|
| 262 |
+
state, _ = env.reset()
|
| 263 |
+
episode_reward = 0
|
| 264 |
+
done = False
|
| 265 |
+
|
| 266 |
+
while not done:
|
| 267 |
+
if render:
|
| 268 |
+
env.render()
|
| 269 |
+
|
| 270 |
+
state_tensor = torch.FloatTensor(state).unsqueeze(0)
|
| 271 |
+
with torch.no_grad():
|
| 272 |
+
probs = policy_net(state_tensor)
|
| 273 |
+
action = torch.argmax(probs).item()
|
| 274 |
+
|
| 275 |
+
next_state, reward, terminated, truncated, _ = env.step(action)
|
| 276 |
+
done = terminated or truncated
|
| 277 |
+
episode_reward += reward
|
| 278 |
+
state = next_state
|
| 279 |
+
|
| 280 |
+
episode_rewards.append(episode_reward)
|
| 281 |
+
print(f'Test Episode {episode + 1}: Reward = {episode_reward}')
|
| 282 |
+
|
| 283 |
+
print(f'Average Test Reward: {np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}')
|
| 284 |
+
return episode_rewards
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def main():
|
| 288 |
+
parser = argparse.ArgumentParser(description='REINFORCE with Baseline for CartPole')
|
| 289 |
+
parser.add_argument('--episodes', type=int, default=1000, help='Number of training episodes')
|
| 290 |
+
parser.add_argument('--policy_lr', type=float, default=1e-3, help='Learning rate for policy network')
|
| 291 |
+
parser.add_argument('--value_lr', type=float, default=1e-3, help='Learning rate for value network')
|
| 292 |
+
parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
|
| 293 |
+
parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden layer dimension')
|
| 294 |
+
parser.add_argument('--render_train', action='store_true', help='Render during training')
|
| 295 |
+
parser.add_argument('--render_test', action='store_true', help='Render during testing')
|
| 296 |
+
args = parser.parse_args()
|
| 297 |
+
|
| 298 |
+
# 创建环境
|
| 299 |
+
env = gym.make('CartPole-v1')
|
| 300 |
+
|
| 301 |
+
# 创建智能体
|
| 302 |
+
agent = REINFORCEwithBaseline(
|
| 303 |
+
env=env,
|
| 304 |
+
policy_lr=args.policy_lr,
|
| 305 |
+
value_lr=args.value_lr,
|
| 306 |
+
gamma=args.gamma,
|
| 307 |
+
hidden_dim=args.hidden_dim,
|
| 308 |
+
render=args.render_train
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# 训练
|
| 312 |
+
print("开始训练 REINFORCE with Baseline...")
|
| 313 |
+
rewards = agent.train(num_episodes=args.episodes)
|
| 314 |
+
|
| 315 |
+
# 绘制训练结果
|
| 316 |
+
plot_training_results(
|
| 317 |
+
agent.training_stats['episode_rewards'],
|
| 318 |
+
agent.training_stats['baseline_loss']
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# 测试
|
| 322 |
+
print("\n测试训练好的模型...")
|
| 323 |
+
test_env = gym.make('CartPole-v1')
|
| 324 |
+
test_agent(test_env, agent.policy_net, num_episodes=10, render=args.render_test)
|
| 325 |
+
|
| 326 |
+
# 关闭环境
|
| 327 |
+
env.close()
|
| 328 |
+
test_env.close()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
if __name__ == "__main__":
|
| 332 |
+
main()
|
examples/tutorials/rl/cart_pole/step_2_rl_dqn.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
深度强化学习 + 基于值 + 模型无关 + 异策略 + 离线学习
|
| 5 |
+
|
| 6 |
+
强化学习算法
|
| 7 |
+
│
|
| 8 |
+
├── 🔵 基于值的方法 (Value-Based)
|
| 9 |
+
│ ├── 传统: Q-learning, SARSA
|
| 10 |
+
│ └── 🎯 DQN ← 在这里!
|
| 11 |
+
│ ├── DQN
|
| 12 |
+
│ ├── Double DQN
|
| 13 |
+
│ ├── Dueling DQN
|
| 14 |
+
│ └── Rainbow
|
| 15 |
+
│
|
| 16 |
+
├── 🔴 基于策略的方法 (Policy-Based)
|
| 17 |
+
│ ├── REINFORCE
|
| 18 |
+
│ ├── PPO
|
| 19 |
+
│ └── TRPO
|
| 20 |
+
│
|
| 21 |
+
└── 🟣 演员-评论家方法 (Actor-Critic)
|
| 22 |
+
├── A2C/A3C
|
| 23 |
+
├── SAC
|
| 24 |
+
└── TD3
|
| 25 |
+
|
| 26 |
+
贝尔曼方程
|
| 27 |
+
|
| 28 |
+
贝尔曼的洞察(1957):
|
| 29 |
+
"最优策略有这样的性质:无论初始状态和初始决策如何,其余的决策必须构成一个以第一个决策产生的状态为初始状态的最优策略。"
|
| 30 |
+
翻译成人话:
|
| 31 |
+
如果A→B→C是最优路径,那么B→C也必须是最优路径!
|
| 32 |
+
|
| 33 |
+
备注:所以 DQN 是基于贝尔曼假设的,当环境不符合这个假设时,该方法不成立。
|
| 34 |
+
|
| 35 |
+
"""
|
| 36 |
+
from collections import deque
|
| 37 |
+
import random
|
| 38 |
+
|
| 39 |
+
import gymnasium as gym
|
| 40 |
+
import numpy as np
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
import torch
|
| 43 |
+
import torch.nn as nn
|
| 44 |
+
import torch.optim as optim
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# 神经网络定义
|
| 48 |
+
class DQN(nn.Module):
|
| 49 |
+
def __init__(self, state_size: int, action_size: int):
|
| 50 |
+
super(DQN, self).__init__()
|
| 51 |
+
self.fc1 = nn.Linear(state_size, 64)
|
| 52 |
+
self.fc2 = nn.Linear(64, 64)
|
| 53 |
+
self.fc3 = nn.Linear(64, action_size)
|
| 54 |
+
self.relu = nn.ReLU()
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
x = self.relu(self.fc1(x))
|
| 58 |
+
x = self.relu(self.fc2(x))
|
| 59 |
+
return self.fc3(x)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# 经验回放缓冲区
|
| 63 |
+
class ReplayBuffer:
|
| 64 |
+
def __init__(self, capacity: int):
|
| 65 |
+
self.buffer = deque(maxlen=capacity)
|
| 66 |
+
|
| 67 |
+
def push(self, state, action, reward, next_state, done):
|
| 68 |
+
"""
|
| 69 |
+
:param state: 状态 = [位置, 速度, 角度, 角速度]
|
| 70 |
+
:param action:
|
| 71 |
+
:param reward: float
|
| 72 |
+
:param next_state:
|
| 73 |
+
:param done:
|
| 74 |
+
:return:
|
| 75 |
+
"""
|
| 76 |
+
self.buffer.append((state, action, reward, next_state, done))
|
| 77 |
+
|
| 78 |
+
def sample(self, batch_size):
|
| 79 |
+
batch = random.sample(self.buffer, batch_size)
|
| 80 |
+
states, actions, rewards, next_states, dones = zip(*batch)
|
| 81 |
+
result = (
|
| 82 |
+
np.array(states),
|
| 83 |
+
np.array(actions),
|
| 84 |
+
np.array(rewards),
|
| 85 |
+
np.array(next_states),
|
| 86 |
+
np.array(dones)
|
| 87 |
+
)
|
| 88 |
+
return result
|
| 89 |
+
|
| 90 |
+
def __len__(self):
|
| 91 |
+
return len(self.buffer)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# DQN智能体
|
| 95 |
+
class DQNAgent:
|
| 96 |
+
def __init__(self, state_size: int, action_size: int):
|
| 97 |
+
self.state_size = state_size
|
| 98 |
+
self.action_size = action_size
|
| 99 |
+
|
| 100 |
+
# 网络
|
| 101 |
+
self.policy_net = DQN(state_size, action_size)
|
| 102 |
+
self.target_net = DQN(state_size, action_size)
|
| 103 |
+
self.target_net.load_state_dict(self.policy_net.state_dict())
|
| 104 |
+
|
| 105 |
+
# 优化器
|
| 106 |
+
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.001)
|
| 107 |
+
|
| 108 |
+
# 超参数
|
| 109 |
+
self.gamma = 0.99
|
| 110 |
+
self.epsilon = 1.0
|
| 111 |
+
self.epsilon_min = 0.01
|
| 112 |
+
self.epsilon_decay = 0.995
|
| 113 |
+
self.batch_size = 64
|
| 114 |
+
self.buffer = ReplayBuffer(10000)
|
| 115 |
+
|
| 116 |
+
# 目标网络更新频率
|
| 117 |
+
self.target_update = 10
|
| 118 |
+
|
| 119 |
+
def select_action(self, state, training=True):
|
| 120 |
+
if training and np.random.random() < self.epsilon:
|
| 121 |
+
return random.randrange(self.action_size)
|
| 122 |
+
|
| 123 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
q_values = self.policy_net(state)
|
| 126 |
+
return q_values.argmax().item()
|
| 127 |
+
|
| 128 |
+
def train_step(self):
|
| 129 |
+
if len(self.buffer) < self.batch_size:
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
# 从缓冲区采样
|
| 133 |
+
states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
|
| 134 |
+
|
| 135 |
+
# 转换为张量
|
| 136 |
+
states = torch.FloatTensor(states)
|
| 137 |
+
actions = torch.LongTensor(actions).unsqueeze(1)
|
| 138 |
+
rewards = torch.FloatTensor(rewards).unsqueeze(1)
|
| 139 |
+
next_states = torch.FloatTensor(next_states)
|
| 140 |
+
dones = torch.FloatTensor(dones).unsqueeze(1)
|
| 141 |
+
|
| 142 |
+
# 计算当前Q值
|
| 143 |
+
actions_logits = self.policy_net(states)
|
| 144 |
+
# actions_logits shape: [batch_size, action_size]
|
| 145 |
+
current_q = torch.gather(actions_logits, dim=1, index=actions)
|
| 146 |
+
# current_q shape: [batch_size, 1]
|
| 147 |
+
|
| 148 |
+
# 计算目标Q值
|
| 149 |
+
with torch.no_grad():
|
| 150 |
+
next_actions_logits = self.target_net(next_states)
|
| 151 |
+
# next_actions_logits shape: [batch_size, action_size]
|
| 152 |
+
next_q, _ = torch.max(next_actions_logits, 1)
|
| 153 |
+
next_q = torch.unsqueeze(next_q, 1)
|
| 154 |
+
# next_q shape: [batch_size, 1]
|
| 155 |
+
# 贝尔曼方程
|
| 156 |
+
target_q = rewards + (1 - dones) * self.gamma * next_q
|
| 157 |
+
|
| 158 |
+
# 计算损失
|
| 159 |
+
# current_q 预测采取当前动作后未来能获取的总奖励。并对远期奖励降权。
|
| 160 |
+
# target_q 上一次预测未来能获取的总奖励 = 当前已获得奖励 + 当前预测未来能获取的总奖励
|
| 161 |
+
loss = nn.MSELoss()(current_q, target_q)
|
| 162 |
+
|
| 163 |
+
# 引入 target_net 而不是直接使用 policy_net 是因为:
|
| 164 |
+
# 每一个 train_step 只训练了一个 batch,但训练本身具有随机性,
|
| 165 |
+
# 因此可以认为,只在训练多个 batch 之后对模型的优化才真实有效。
|
| 166 |
+
# 所以设置为定期将 target_net 与 policy_net 同步。
|
| 167 |
+
|
| 168 |
+
# 优化
|
| 169 |
+
self.optimizer.zero_grad()
|
| 170 |
+
loss.backward()
|
| 171 |
+
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
| 172 |
+
self.optimizer.step()
|
| 173 |
+
|
| 174 |
+
# 衰减探索率
|
| 175 |
+
if self.epsilon > self.epsilon_min:
|
| 176 |
+
self.epsilon *= self.epsilon_decay
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# 训练函数
|
| 180 |
+
def train_dqn():
|
| 181 |
+
env = gym.make("CartPole-v1")
|
| 182 |
+
state_size = env.observation_space.shape[0]
|
| 183 |
+
action_size = env.action_space.n
|
| 184 |
+
|
| 185 |
+
agent = DQNAgent(state_size, action_size)
|
| 186 |
+
episodes = 500
|
| 187 |
+
rewards = []
|
| 188 |
+
|
| 189 |
+
for episode in range(episodes):
|
| 190 |
+
state, _ = env.reset()
|
| 191 |
+
total_reward = 0
|
| 192 |
+
done = False
|
| 193 |
+
|
| 194 |
+
while not done:
|
| 195 |
+
# 选择动作
|
| 196 |
+
action = agent.select_action(state)
|
| 197 |
+
|
| 198 |
+
# 执行动作
|
| 199 |
+
next_state, reward, terminated, truncated, _ = env.step(action)
|
| 200 |
+
done = terminated or truncated
|
| 201 |
+
|
| 202 |
+
# 存储经验
|
| 203 |
+
agent.buffer.push(state, action, reward, next_state, done)
|
| 204 |
+
|
| 205 |
+
# 训练
|
| 206 |
+
agent.train_step()
|
| 207 |
+
|
| 208 |
+
state = next_state
|
| 209 |
+
total_reward += reward
|
| 210 |
+
|
| 211 |
+
# 更新目标网络
|
| 212 |
+
if episode % agent.target_update == 0:
|
| 213 |
+
agent.target_net.load_state_dict(agent.policy_net.state_dict())
|
| 214 |
+
|
| 215 |
+
rewards.append(total_reward)
|
| 216 |
+
|
| 217 |
+
if (episode + 1) % 50 == 0:
|
| 218 |
+
avg_reward = np.mean(rewards[-50:])
|
| 219 |
+
print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Epsilon: {agent.epsilon:.3f}")
|
| 220 |
+
|
| 221 |
+
env.close()
|
| 222 |
+
|
| 223 |
+
# 可视化
|
| 224 |
+
plt.figure(figsize=(12, 4))
|
| 225 |
+
|
| 226 |
+
plt.subplot(1, 2, 1)
|
| 227 |
+
plt.plot(rewards)
|
| 228 |
+
plt.xlabel('Episode')
|
| 229 |
+
plt.ylabel('Reward')
|
| 230 |
+
plt.title('Training Rewards')
|
| 231 |
+
|
| 232 |
+
plt.subplot(1, 2, 2)
|
| 233 |
+
window = 20
|
| 234 |
+
moving_avg = [np.mean(rewards[max(0, i - window):i + 1]) for i in range(len(rewards))]
|
| 235 |
+
plt.plot(moving_avg)
|
| 236 |
+
plt.xlabel('Episode')
|
| 237 |
+
plt.ylabel('Moving Avg Reward')
|
| 238 |
+
plt.title(f'Moving Average (window={window})')
|
| 239 |
+
|
| 240 |
+
plt.tight_layout()
|
| 241 |
+
plt.show()
|
| 242 |
+
|
| 243 |
+
return agent
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# 运行训练
|
| 247 |
+
agent = train_dqn()
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
if __name__ == "__main__":
|
| 251 |
+
pass
|
examples/tutorials/rlhf/gpt2_sst2/step_1_prepare_data.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
或使用命令行
|
| 5 |
+
pip install modelscope
|
| 6 |
+
modelscope download \
|
| 7 |
+
--model 'qgyd2021/Qwen3-8B-sft-deepspeed' \
|
| 8 |
+
--local_dir '/root/autodl-tmp/trained_models/Qwen3-8B-sft-deepspeed'
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import platform
|
| 15 |
+
|
| 16 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 17 |
+
|
| 18 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 19 |
+
from project_settings import project_path, temp_directory
|
| 20 |
+
else:
|
| 21 |
+
project_path = os.path.abspath("../../../")
|
| 22 |
+
project_path = Path(project_path)
|
| 23 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 24 |
+
|
| 25 |
+
from modelscope import snapshot_download
|
| 26 |
+
# from huggingface_hub import snapshot_download
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_args():
|
| 30 |
+
parser = argparse.ArgumentParser()
|
| 31 |
+
parser.add_argument("--repo_id", default="Qwen/Qwen2.5-0.5B", type=str)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--local_dir",
|
| 34 |
+
default=(temp_directory / "../trained_models/Qwen/Qwen2.5-0.5B").as_posix(),
|
| 35 |
+
type=str
|
| 36 |
+
)
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
return args
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main():
|
| 42 |
+
args = get_args()
|
| 43 |
+
|
| 44 |
+
#modelscope
|
| 45 |
+
# snapshot_download(
|
| 46 |
+
# model_id=args.repo_id,
|
| 47 |
+
# local_dir=args.local_dir,
|
| 48 |
+
# )
|
| 49 |
+
#huggingface_hub
|
| 50 |
+
snapshot_download(
|
| 51 |
+
repo_type="model",
|
| 52 |
+
repo_id=args.repo_id,
|
| 53 |
+
local_dir=args.local_dir,
|
| 54 |
+
)
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
main()
|
examples/tutorials/rlhf/gpt2_sst2/step_2_train_sft_model.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
用sst的句子训练gpt2模型,让其随机生成一些评论。
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import platform
|
| 10 |
+
|
| 11 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 12 |
+
from project_settings import project_path, temp_directory
|
| 13 |
+
else:
|
| 14 |
+
project_path = os.path.abspath("../../../")
|
| 15 |
+
project_path = Path(project_path)
|
| 16 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 17 |
+
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_args():
|
| 24 |
+
parser = argparse.ArgumentParser()
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--model_name",
|
| 27 |
+
# default="openai-community/gpt2",
|
| 28 |
+
default=(project_path / "pretrained_models/openai-community/gpt2").as_posix(),
|
| 29 |
+
type=str
|
| 30 |
+
),
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--dataset_path",
|
| 33 |
+
default="stanfordnlp/sst2",
|
| 34 |
+
type=str
|
| 35 |
+
),
|
| 36 |
+
parser.add_argument("--dataset_name", default=None, type=str),
|
| 37 |
+
parser.add_argument("--dataset_split", default=None, type=str),
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--dataset_cache_dir",
|
| 40 |
+
default=(temp_directory / "hub_datasets").as_posix(),
|
| 41 |
+
type=str
|
| 42 |
+
),
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--model_cache_dir",
|
| 45 |
+
default=(temp_directory / "hub_models").as_posix(),
|
| 46 |
+
type=str
|
| 47 |
+
),
|
| 48 |
+
parser.add_argument("--dataset_streaming", default=None, type=str),
|
| 49 |
+
parser.add_argument("--valid_dataset_size", default=1000, type=int),
|
| 50 |
+
parser.add_argument("--shuffle_buffer_size", default=5000, type=int),
|
| 51 |
+
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--output_model_dir",
|
| 54 |
+
default=(project_path / "trained_models/gpt2-sst2-generation-20260213-2048").as_posix(),
|
| 55 |
+
type=str
|
| 56 |
+
),
|
| 57 |
+
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--num_workers",
|
| 60 |
+
default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
|
| 61 |
+
type=int
|
| 62 |
+
),
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--device",
|
| 65 |
+
default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
| 66 |
+
type=int
|
| 67 |
+
),
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
return args
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main():
|
| 73 |
+
args = get_args()
|
| 74 |
+
|
| 75 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name)
|
| 76 |
+
model = model.to(args.device)
|
| 77 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 78 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 79 |
+
|
| 80 |
+
dataset_dict = load_dataset(
|
| 81 |
+
path=args.dataset_path,
|
| 82 |
+
name=args.dataset_name,
|
| 83 |
+
split=args.dataset_split,
|
| 84 |
+
cache_dir=args.dataset_cache_dir,
|
| 85 |
+
# num_proc=args.num_workers if not args.dataset_streaming else None,
|
| 86 |
+
streaming=args.dataset_streaming,
|
| 87 |
+
)
|
| 88 |
+
train_dataset = dataset_dict["train"]
|
| 89 |
+
valid_dataset = dataset_dict["validation"]
|
| 90 |
+
# test_dataset = dataset_dict["test"]
|
| 91 |
+
|
| 92 |
+
def format_func(example):
|
| 93 |
+
sentence = example["sentence"]
|
| 94 |
+
sentence += tokenizer.eos_token
|
| 95 |
+
tokenized = tokenizer(sentence)
|
| 96 |
+
input_ids = tokenized["input_ids"]
|
| 97 |
+
attention_mask = tokenized["attention_mask"]
|
| 98 |
+
# print(input_ids)
|
| 99 |
+
# print(attention_mask)
|
| 100 |
+
result = {
|
| 101 |
+
"input_ids": input_ids,
|
| 102 |
+
"attention_mask": attention_mask,
|
| 103 |
+
}
|
| 104 |
+
return result
|
| 105 |
+
|
| 106 |
+
train_dataset = train_dataset.map(
|
| 107 |
+
format_func,
|
| 108 |
+
batched=False,
|
| 109 |
+
remove_columns=train_dataset.column_names,
|
| 110 |
+
)
|
| 111 |
+
valid_dataset = valid_dataset.map(
|
| 112 |
+
format_func,
|
| 113 |
+
batched=False,
|
| 114 |
+
remove_columns=valid_dataset.column_names,
|
| 115 |
+
)
|
| 116 |
+
train_dataset = train_dataset.filter(
|
| 117 |
+
function=lambda x: len(x["input_ids"]) > 5
|
| 118 |
+
)
|
| 119 |
+
valid_dataset = valid_dataset.filter(
|
| 120 |
+
function=lambda x: len(x["input_ids"]) > 5
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 124 |
+
tokenizer,
|
| 125 |
+
mlm=False
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
training_args = TrainingArguments(
|
| 129 |
+
output_dir=args.output_model_dir,
|
| 130 |
+
# overwrite_output_dir=True,
|
| 131 |
+
num_train_epochs=1,
|
| 132 |
+
per_device_train_batch_size=16,
|
| 133 |
+
per_device_eval_batch_size=16,
|
| 134 |
+
eval_strategy="steps",
|
| 135 |
+
eval_steps=100,
|
| 136 |
+
save_strategy="steps",
|
| 137 |
+
save_steps=100,
|
| 138 |
+
save_total_limit=2,
|
| 139 |
+
logging_steps=100,
|
| 140 |
+
learning_rate=5e-5,
|
| 141 |
+
warmup_steps=500,
|
| 142 |
+
weight_decay=0.01,
|
| 143 |
+
fp16=torch.cuda.is_available(),
|
| 144 |
+
dataloader_num_workers=args.num_workers or 0,
|
| 145 |
+
remove_unused_columns=False,
|
| 146 |
+
load_best_model_at_end=False,
|
| 147 |
+
# metric_for_best_model="eval_loss",
|
| 148 |
+
# greater_is_better=False,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
trainer = Trainer(
|
| 152 |
+
model=model,
|
| 153 |
+
args=training_args,
|
| 154 |
+
data_collator=data_collator,
|
| 155 |
+
train_dataset=train_dataset,
|
| 156 |
+
eval_dataset=valid_dataset,
|
| 157 |
+
tokenizer=tokenizer,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
trainer.train()
|
| 161 |
+
trainer.save_model()
|
| 162 |
+
return
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
main()
|
examples/tutorials/rlhf/gpt2_sst2/step_3_train_reward_model.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import platform
|
| 7 |
+
from typing import Any, Dict, List, Optional, Union, Tuple
|
| 8 |
+
|
| 9 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 10 |
+
from project_settings import project_path, temp_directory
|
| 11 |
+
else:
|
| 12 |
+
project_path = os.path.abspath("../../../")
|
| 13 |
+
project_path = Path(project_path)
|
| 14 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 15 |
+
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from transformers import (AutoModelForCausalLM,
|
| 22 |
+
AutoTokenizer,
|
| 23 |
+
GPT2PreTrainedModel, GPT2Config, GPT2Model,
|
| 24 |
+
DataCollatorWithPadding,
|
| 25 |
+
Trainer, TrainingArguments
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_args():
|
| 30 |
+
parser = argparse.ArgumentParser()
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--model_name",
|
| 33 |
+
# default="openai-community/gpt2",
|
| 34 |
+
default=(project_path / "pretrained_models/openai-community/gpt2").as_posix(),
|
| 35 |
+
type=str
|
| 36 |
+
),
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--dataset_path",
|
| 39 |
+
default="stanfordnlp/sst2",
|
| 40 |
+
type=str
|
| 41 |
+
),
|
| 42 |
+
parser.add_argument("--dataset_name", default=None, type=str),
|
| 43 |
+
parser.add_argument("--dataset_split", default=None, type=str),
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--dataset_cache_dir",
|
| 46 |
+
default=(temp_directory / "hub_datasets").as_posix(),
|
| 47 |
+
type=str
|
| 48 |
+
),
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--model_cache_dir",
|
| 51 |
+
default=(temp_directory / "hub_models").as_posix(),
|
| 52 |
+
type=str
|
| 53 |
+
),
|
| 54 |
+
parser.add_argument("--dataset_streaming", default=None, type=str),
|
| 55 |
+
parser.add_argument("--valid_dataset_size", default=1000, type=int),
|
| 56 |
+
parser.add_argument("--shuffle_buffer_size", default=5000, type=int),
|
| 57 |
+
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--output_model_dir",
|
| 60 |
+
default=(project_path / "trained_models/gpt2-sst2-reward-20260213-2122").as_posix(),
|
| 61 |
+
type=str
|
| 62 |
+
),
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--num_workers",
|
| 65 |
+
default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
|
| 66 |
+
type=int
|
| 67 |
+
),
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--device",
|
| 70 |
+
default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
| 71 |
+
type=str
|
| 72 |
+
),
|
| 73 |
+
args = parser.parse_args()
|
| 74 |
+
return args
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class RewardHead(nn.Module):
|
| 78 |
+
def __init__(self, hidden_size: int):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.hidden_size = hidden_size
|
| 81 |
+
self.linear = nn.Linear(self.hidden_size, 1)
|
| 82 |
+
self._post_init()
|
| 83 |
+
|
| 84 |
+
def _post_init(self):
|
| 85 |
+
nn.init.normal_(
|
| 86 |
+
self.linear.weight,
|
| 87 |
+
std=(1.0 / np.sqrt(self.hidden_size + 1))
|
| 88 |
+
)
|
| 89 |
+
nn.init.zeros_(self.linear.bias)
|
| 90 |
+
|
| 91 |
+
def forward(self, hidden_states):
|
| 92 |
+
# hidden_states shape: [batch_size, seq_len, hidden_size]
|
| 93 |
+
reward_logits = self.linear(hidden_states)
|
| 94 |
+
# reward_logits shape: [batch_size, seq_len, 1]
|
| 95 |
+
return reward_logits
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class GPT2RewardModel(GPT2PreTrainedModel):
|
| 99 |
+
def __init__(self, config: GPT2Config):
|
| 100 |
+
super().__init__(config)
|
| 101 |
+
self.transformer = GPT2Model(config)
|
| 102 |
+
self.reward_head = RewardHead(config.hidden_size)
|
| 103 |
+
self.post_init()
|
| 104 |
+
|
| 105 |
+
def forward(
|
| 106 |
+
self,
|
| 107 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 108 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 109 |
+
) -> Union[Tuple, torch.Tensor]:
|
| 110 |
+
transformer_outputs = self.transformer(
|
| 111 |
+
input_ids=input_ids,
|
| 112 |
+
attention_mask=attention_mask,
|
| 113 |
+
output_hidden_states=True
|
| 114 |
+
)
|
| 115 |
+
last_hidden_state = transformer_outputs.hidden_states[-1]
|
| 116 |
+
# last_hidden_state shape: [batch_size, seq_len, hidden_size]
|
| 117 |
+
rewards_logits = self.reward_head(last_hidden_state)
|
| 118 |
+
# rewards_logits shape: [batch_size, seq_len, 1]
|
| 119 |
+
rewards_logits = torch.squeeze(rewards_logits, -1)
|
| 120 |
+
# rewards_logits shape: [batch_size, seq_len]
|
| 121 |
+
rewards = torch.sigmoid(rewards_logits)
|
| 122 |
+
# rewards shape: [batch_size, seq_len]
|
| 123 |
+
return rewards
|
| 124 |
+
|
| 125 |
+
@classmethod
|
| 126 |
+
def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
|
| 127 |
+
config = GPT2Config.from_pretrained(model_name_or_path)
|
| 128 |
+
model = cls(config)
|
| 129 |
+
pretrained_model = GPT2Model.from_pretrained(model_name_or_path)
|
| 130 |
+
model.transformer.load_state_dict(pretrained_model.state_dict(), strict=False)
|
| 131 |
+
return model
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class SST2RewardTrainer(Trainer):
|
| 135 |
+
def compute_loss(
|
| 136 |
+
self,
|
| 137 |
+
model: nn.Module,
|
| 138 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 139 |
+
return_outputs: bool = False,
|
| 140 |
+
num_items_in_batch: Optional[torch.Tensor] = None,
|
| 141 |
+
):
|
| 142 |
+
rewards = model(
|
| 143 |
+
input_ids=inputs["input_ids"],
|
| 144 |
+
attention_mask=inputs["attention_mask"]
|
| 145 |
+
)
|
| 146 |
+
sequence_lengths = inputs["attention_mask"].sum(dim=1) - 1
|
| 147 |
+
batch_indices = torch.arange(rewards.size(0), device=rewards.device)
|
| 148 |
+
sequence_reward = rewards[batch_indices, sequence_lengths]
|
| 149 |
+
# sequence_reward shape: [batch_size,]
|
| 150 |
+
loss = F.mse_loss(
|
| 151 |
+
sequence_reward,
|
| 152 |
+
inputs["score"].float()
|
| 153 |
+
)
|
| 154 |
+
if return_outputs:
|
| 155 |
+
return loss, {
|
| 156 |
+
"loss": loss,
|
| 157 |
+
"predictions": sequence_reward.detach(),
|
| 158 |
+
}
|
| 159 |
+
return loss
|
| 160 |
+
|
| 161 |
+
def prediction_step(
|
| 162 |
+
self,
|
| 163 |
+
model: nn.Module,
|
| 164 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 165 |
+
prediction_loss_only: bool,
|
| 166 |
+
ignore_keys: Optional[list[str]] = None,
|
| 167 |
+
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
| 170 |
+
|
| 171 |
+
if prediction_loss_only:
|
| 172 |
+
return loss, None, None
|
| 173 |
+
|
| 174 |
+
predictions = outputs["predictions"]
|
| 175 |
+
labels = inputs["score"].float()
|
| 176 |
+
return loss, predictions, labels
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def compute_metrics(eval_pred):
|
| 180 |
+
"""计算评估指标"""
|
| 181 |
+
predictions, labels = eval_pred
|
| 182 |
+
predictions = torch.tensor(predictions)
|
| 183 |
+
labels = torch.tensor(labels)
|
| 184 |
+
|
| 185 |
+
error = (predictions - labels).abs()
|
| 186 |
+
|
| 187 |
+
return {
|
| 188 |
+
"mean_error": error.mean().item(),
|
| 189 |
+
"std_error": error.std().item(),
|
| 190 |
+
"reward_mean": predictions.mean().item(),
|
| 191 |
+
"reward_min": predictions.min().item(),
|
| 192 |
+
"reward_max": predictions.max().item(),
|
| 193 |
+
"score_mean": labels.mean().item(),
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def main():
|
| 198 |
+
args = get_args()
|
| 199 |
+
|
| 200 |
+
model = GPT2RewardModel.from_pretrained(args.model_name)
|
| 201 |
+
model = model.to(args.device)
|
| 202 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 203 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 204 |
+
|
| 205 |
+
# tokenized = tokenizer(
|
| 206 |
+
# "this is a good day",
|
| 207 |
+
# # "this is ",
|
| 208 |
+
# return_tensors="pt"
|
| 209 |
+
# )
|
| 210 |
+
# output_dict = model(**tokenized)
|
| 211 |
+
|
| 212 |
+
dataset_dict = load_dataset(
|
| 213 |
+
path=args.dataset_path,
|
| 214 |
+
name=args.dataset_name,
|
| 215 |
+
split=args.dataset_split,
|
| 216 |
+
cache_dir=args.dataset_cache_dir,
|
| 217 |
+
# num_proc=args.num_workers if not args.dataset_streaming else None,
|
| 218 |
+
streaming=args.dataset_streaming,
|
| 219 |
+
)
|
| 220 |
+
train_dataset = dataset_dict["train"]
|
| 221 |
+
valid_dataset = dataset_dict["validation"]
|
| 222 |
+
# test_dataset = dataset_dict["test"]
|
| 223 |
+
|
| 224 |
+
def format_func(example):
|
| 225 |
+
sentence: str = example["sentence"]
|
| 226 |
+
score: float = float(example["label"])
|
| 227 |
+
sentence += tokenizer.eos_token
|
| 228 |
+
tokenized = tokenizer(sentence)
|
| 229 |
+
input_ids = tokenized["input_ids"]
|
| 230 |
+
attention_mask = tokenized["attention_mask"]
|
| 231 |
+
result = {
|
| 232 |
+
"input_ids": input_ids,
|
| 233 |
+
"attention_mask": attention_mask,
|
| 234 |
+
"score": score,
|
| 235 |
+
}
|
| 236 |
+
return result
|
| 237 |
+
|
| 238 |
+
train_dataset = train_dataset.map(
|
| 239 |
+
format_func,
|
| 240 |
+
batched=False,
|
| 241 |
+
remove_columns=train_dataset.column_names,
|
| 242 |
+
)
|
| 243 |
+
valid_dataset = valid_dataset.map(
|
| 244 |
+
format_func,
|
| 245 |
+
batched=False,
|
| 246 |
+
remove_columns=valid_dataset.column_names,
|
| 247 |
+
)
|
| 248 |
+
train_dataset = train_dataset.filter(
|
| 249 |
+
function=lambda x: len(x["input_ids"]) > 6
|
| 250 |
+
)
|
| 251 |
+
valid_dataset = valid_dataset.filter(
|
| 252 |
+
function=lambda x: len(x["input_ids"]) > 6
|
| 253 |
+
)
|
| 254 |
+
data_collator = DataCollatorWithPadding(tokenizer)
|
| 255 |
+
|
| 256 |
+
training_args = TrainingArguments(
|
| 257 |
+
output_dir=args.output_model_dir,
|
| 258 |
+
# overwrite_output_dir=True,
|
| 259 |
+
num_train_epochs=1,
|
| 260 |
+
per_device_train_batch_size=16,
|
| 261 |
+
per_device_eval_batch_size=16,
|
| 262 |
+
eval_strategy="steps",
|
| 263 |
+
eval_steps=500,
|
| 264 |
+
save_strategy="steps",
|
| 265 |
+
save_steps=500,
|
| 266 |
+
save_total_limit=2,
|
| 267 |
+
logging_steps=500,
|
| 268 |
+
learning_rate=5e-5,
|
| 269 |
+
warmup_steps=1000,
|
| 270 |
+
weight_decay=0.01,
|
| 271 |
+
fp16=torch.cuda.is_available(),
|
| 272 |
+
dataloader_num_workers=args.num_workers or 0,
|
| 273 |
+
remove_unused_columns=False,
|
| 274 |
+
load_best_model_at_end=True,
|
| 275 |
+
metric_for_best_model="eval_loss",
|
| 276 |
+
greater_is_better=False,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
trainer = SST2RewardTrainer(
|
| 280 |
+
model=model,
|
| 281 |
+
args=training_args,
|
| 282 |
+
data_collator=data_collator,
|
| 283 |
+
train_dataset=train_dataset,
|
| 284 |
+
eval_dataset=valid_dataset,
|
| 285 |
+
tokenizer=tokenizer,
|
| 286 |
+
compute_metrics=compute_metrics,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
trainer.train()
|
| 290 |
+
trainer.save_model()
|
| 291 |
+
return
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
main()
|
examples/tutorials/rlhf/gpt2_sst2/step_4_test_reward_model.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import platform
|
| 7 |
+
from typing import Any, Dict, List, Optional, Union, Tuple
|
| 8 |
+
|
| 9 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 10 |
+
from project_settings import project_path, temp_directory
|
| 11 |
+
else:
|
| 12 |
+
project_path = os.path.abspath("../../../")
|
| 13 |
+
project_path = Path(project_path)
|
| 14 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 15 |
+
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from transformers import (AutoTokenizer,
|
| 21 |
+
GPT2PreTrainedModel, GPT2Config, GPT2Model,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_args():
|
| 26 |
+
parser = argparse.ArgumentParser()
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--model_name",
|
| 29 |
+
# default="openai-community/gpt2",
|
| 30 |
+
# default=(project_path / "trained_models/gpt2-sst2-reward").as_posix(),
|
| 31 |
+
default=(project_path / "trained_models/gpt2-sst2-reward-20260213-2122").as_posix(),
|
| 32 |
+
type=str
|
| 33 |
+
),
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--dataset_path",
|
| 36 |
+
default="stanfordnlp/sst2",
|
| 37 |
+
type=str
|
| 38 |
+
),
|
| 39 |
+
parser.add_argument("--dataset_name", default=None, type=str),
|
| 40 |
+
parser.add_argument("--dataset_split", default=None, type=str),
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--dataset_cache_dir",
|
| 43 |
+
default=(temp_directory / "hub_datasets").as_posix(),
|
| 44 |
+
type=str
|
| 45 |
+
),
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--model_cache_dir",
|
| 48 |
+
default=(temp_directory / "hub_models").as_posix(),
|
| 49 |
+
type=str
|
| 50 |
+
),
|
| 51 |
+
parser.add_argument("--dataset_streaming", default=None, type=str),
|
| 52 |
+
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--num_workers",
|
| 55 |
+
default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
|
| 56 |
+
type=int
|
| 57 |
+
),
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--device",
|
| 60 |
+
default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
| 61 |
+
type=str
|
| 62 |
+
),
|
| 63 |
+
args = parser.parse_args()
|
| 64 |
+
return args
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RewardHead(nn.Module):
|
| 68 |
+
def __init__(self, hidden_size: int):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.hidden_size = hidden_size
|
| 71 |
+
self.linear = nn.Linear(self.hidden_size, 1)
|
| 72 |
+
self._post_init()
|
| 73 |
+
|
| 74 |
+
def _post_init(self):
|
| 75 |
+
nn.init.normal_(
|
| 76 |
+
self.linear.weight,
|
| 77 |
+
std=(1.0 / np.sqrt(self.hidden_size + 1))
|
| 78 |
+
)
|
| 79 |
+
nn.init.zeros_(self.linear.bias)
|
| 80 |
+
|
| 81 |
+
def forward(self, hidden_states):
|
| 82 |
+
# hidden_states shape: [batch_size, seq_len, hidden_size]
|
| 83 |
+
reward_logits = self.linear(hidden_states)
|
| 84 |
+
# reward_logits shape: [batch_size, seq_len, 1]
|
| 85 |
+
return reward_logits
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class GPT2RewardModel(GPT2PreTrainedModel):
|
| 89 |
+
def __init__(self, config: GPT2Config):
|
| 90 |
+
super().__init__(config)
|
| 91 |
+
self.transformer = GPT2Model(config)
|
| 92 |
+
self.reward_head = RewardHead(config.hidden_size)
|
| 93 |
+
self.post_init()
|
| 94 |
+
|
| 95 |
+
def forward(
|
| 96 |
+
self,
|
| 97 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 98 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 99 |
+
) -> Union[Tuple, torch.Tensor]:
|
| 100 |
+
transformer_outputs = self.transformer(
|
| 101 |
+
input_ids=input_ids,
|
| 102 |
+
attention_mask=attention_mask,
|
| 103 |
+
output_hidden_states=True
|
| 104 |
+
)
|
| 105 |
+
last_hidden_state = transformer_outputs.hidden_states[-1]
|
| 106 |
+
# last_hidden_state shape: [batch_size, seq_len, hidden_size]
|
| 107 |
+
rewards_logits = self.reward_head(last_hidden_state)
|
| 108 |
+
# rewards_logits shape: [batch_size, seq_len, 1]
|
| 109 |
+
rewards_logits = torch.squeeze(rewards_logits, -1)
|
| 110 |
+
# rewards_logits shape: [batch_size, seq_len]
|
| 111 |
+
rewards = torch.sigmoid(rewards_logits)
|
| 112 |
+
# rewards shape: [batch_size, seq_len]
|
| 113 |
+
return rewards
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def main():
|
| 117 |
+
args = get_args()
|
| 118 |
+
|
| 119 |
+
model = GPT2RewardModel.from_pretrained(args.model_name)
|
| 120 |
+
model = model.to(args.device)
|
| 121 |
+
# model.eval()
|
| 122 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| 123 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 124 |
+
|
| 125 |
+
dataset_dict = load_dataset(
|
| 126 |
+
path=args.dataset_path,
|
| 127 |
+
name=args.dataset_name,
|
| 128 |
+
split=args.dataset_split,
|
| 129 |
+
cache_dir=args.dataset_cache_dir,
|
| 130 |
+
# num_proc=args.num_workers if not args.dataset_streaming else None,
|
| 131 |
+
streaming=args.dataset_streaming,
|
| 132 |
+
)
|
| 133 |
+
# dataset = dataset_dict["train"]
|
| 134 |
+
dataset = dataset_dict["validation"]
|
| 135 |
+
# dataset = dataset_dict["test"]
|
| 136 |
+
|
| 137 |
+
for example in dataset:
|
| 138 |
+
sentence: str = example["sentence"]
|
| 139 |
+
score: float = float(example["label"])
|
| 140 |
+
# sentence = "this is very good movie, I recommend it."
|
| 141 |
+
|
| 142 |
+
sentence += tokenizer.eos_token
|
| 143 |
+
tokenized = tokenizer(
|
| 144 |
+
sentence,
|
| 145 |
+
return_tensors="pt"
|
| 146 |
+
)
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
rewards = model(**tokenized)
|
| 149 |
+
rewards = rewards[0]
|
| 150 |
+
rewards = rewards.detach().cpu().numpy()
|
| 151 |
+
last_token_reward = rewards[-1]
|
| 152 |
+
#rewards: {rewards}\n
|
| 153 |
+
msg = f"last_token_reward: {last_token_reward}\nscore: {score}\nsentence: {sentence}\n"
|
| 154 |
+
print(msg)
|
| 155 |
+
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
main()
|
examples/tutorials/rlhf/gpt2_sst2/step_5_ppo_rlhf.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import copy
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import platform
|
| 9 |
+
from typing import Optional, Tuple, List, Dict, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
from transformers import (
|
| 18 |
+
AutoTokenizer, AutoModelForCausalLM, GPT2PreTrainedModel,
|
| 19 |
+
GPT2Config, GPT2Model, GPT2LMHeadModel, DataCollatorWithPadding
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# 路径配置
|
| 23 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 24 |
+
from project_settings import project_path, temp_directory
|
| 25 |
+
else:
|
| 26 |
+
project_path = Path(os.path.abspath("../../../"))
|
| 27 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_args():
|
| 31 |
+
parser = argparse.ArgumentParser()
|
| 32 |
+
parser.add_argument("--reward_model_name", type=str,
|
| 33 |
+
default=(project_path / "trained_models/gpt2-sst2-reward").as_posix())
|
| 34 |
+
parser.add_argument("--sft_model_name", type=str,
|
| 35 |
+
default=(project_path / "trained_models/gpt2-sst2-generation").as_posix())
|
| 36 |
+
parser.add_argument("--dataset_path", default="stanfordnlp/sst2", type=str)
|
| 37 |
+
parser.add_argument("--dataset_cache_dir",
|
| 38 |
+
default=(temp_directory / "hub_datasets").as_posix(), type=str)
|
| 39 |
+
parser.add_argument("--model_cache_dir",
|
| 40 |
+
default=(temp_directory / "hub_models").as_posix(), type=str)
|
| 41 |
+
parser.add_argument("--valid_dataset_size", default=1000, type=int)
|
| 42 |
+
|
| 43 |
+
# 训练参数
|
| 44 |
+
parser.add_argument("--batch_size", default=16, type=int) # CPU上用小一点的batch
|
| 45 |
+
parser.add_argument("--ppo_epochs", default=4, type=int)
|
| 46 |
+
parser.add_argument("--mini_batch_size", default=4, type=int)
|
| 47 |
+
parser.add_argument("--kl_beta", default=0.2, type=float)
|
| 48 |
+
parser.add_argument("--gamma", default=1.0, type=float)
|
| 49 |
+
parser.add_argument("--lam", default=0.95, type=float)
|
| 50 |
+
parser.add_argument("--clip_epsilon", default=0.2, type=float)
|
| 51 |
+
parser.add_argument("--lr", default=1e-5, type=float)
|
| 52 |
+
parser.add_argument("--max_epochs", default=10, type=int)
|
| 53 |
+
|
| 54 |
+
# 生成参数
|
| 55 |
+
parser.add_argument("--max_new_tokens", default=32, type=int)
|
| 56 |
+
parser.add_argument("--top_p", default=0.85, type=float)
|
| 57 |
+
parser.add_argument("--temperature", default=0.85, type=float)
|
| 58 |
+
parser.add_argument("--min_response_len", default=5, type=int)
|
| 59 |
+
parser.add_argument("--max_response_len", default=16, type=int)
|
| 60 |
+
|
| 61 |
+
# 其他
|
| 62 |
+
parser.add_argument("--num_workers", default=0 if platform.system() == "Windows" else 2, type=int)
|
| 63 |
+
parser.add_argument("--device", default="cpu", type=str) # 强制用CPU
|
| 64 |
+
|
| 65 |
+
return parser.parse_args()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ValueHead(nn.Module):
|
| 69 |
+
"""价值头,为每个token预测一个价值"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, hidden_size: int):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.linear = nn.Linear(hidden_size, 1)
|
| 74 |
+
self._init_weights()
|
| 75 |
+
|
| 76 |
+
def _init_weights(self):
|
| 77 |
+
nn.init.normal_(self.linear.weight, std=1.0 / np.sqrt(self.linear.in_features + 1))
|
| 78 |
+
nn.init.zeros_(self.linear.bias)
|
| 79 |
+
|
| 80 |
+
def forward(self, hidden_states):
|
| 81 |
+
return self.linear(hidden_states).squeeze(-1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class GPT2ActorCritic(GPT2PreTrainedModel):
|
| 85 |
+
"""Actor-Critic模型,同时输出logits和values"""
|
| 86 |
+
|
| 87 |
+
def __init__(self, config: GPT2Config):
|
| 88 |
+
super().__init__(config)
|
| 89 |
+
self.lm = GPT2LMHeadModel(config)
|
| 90 |
+
self.value_head = ValueHead(config.hidden_size)
|
| 91 |
+
self.post_init()
|
| 92 |
+
|
| 93 |
+
def forward(self, input_ids, attention_mask=None):
|
| 94 |
+
outputs = self.lm(
|
| 95 |
+
input_ids,
|
| 96 |
+
attention_mask=attention_mask,
|
| 97 |
+
output_hidden_states=True
|
| 98 |
+
)
|
| 99 |
+
# values来自最后一层hidden states
|
| 100 |
+
values = self.value_head(outputs.hidden_states[-1])
|
| 101 |
+
return outputs.logits, values
|
| 102 |
+
|
| 103 |
+
def generate(self, *args, **kwargs):
|
| 104 |
+
return self.lm.generate(*args, **kwargs)
|
| 105 |
+
|
| 106 |
+
@classmethod
|
| 107 |
+
def from_pretrained(cls, pretrained_model_name):
|
| 108 |
+
"""从预训练GPT2LMHeadModel加载"""
|
| 109 |
+
config = GPT2Config.from_pretrained(pretrained_model_name)
|
| 110 |
+
model = cls(config)
|
| 111 |
+
pretrained = GPT2LMHeadModel.from_pretrained(pretrained_model_name)
|
| 112 |
+
model.lm.load_state_dict(pretrained.state_dict(), strict=False)
|
| 113 |
+
return model
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class GPT2RewardModel(GPT2PreTrainedModel):
|
| 117 |
+
"""奖励模型,为每个token预测奖励"""
|
| 118 |
+
|
| 119 |
+
def __init__(self, config: GPT2Config):
|
| 120 |
+
super().__init__(config)
|
| 121 |
+
self.transformer = GPT2Model(config)
|
| 122 |
+
self.reward_head = nn.Linear(config.hidden_size, 1)
|
| 123 |
+
self.post_init()
|
| 124 |
+
|
| 125 |
+
def forward(self, input_ids, attention_mask=None):
|
| 126 |
+
outputs = self.transformer(
|
| 127 |
+
input_ids,
|
| 128 |
+
attention_mask=attention_mask,
|
| 129 |
+
output_hidden_states=True
|
| 130 |
+
)
|
| 131 |
+
rewards = self.reward_head(outputs.hidden_states[-1]).squeeze(-1)
|
| 132 |
+
return torch.sigmoid(rewards) # [batch, seq_len]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class PPOAgent:
|
| 136 |
+
"""PPO训练Agent,封装所有训练逻辑"""
|
| 137 |
+
|
| 138 |
+
def __init__(self, args):
|
| 139 |
+
self.args = args
|
| 140 |
+
self.device = torch.device(args.device)
|
| 141 |
+
|
| 142 |
+
# 加载tokenizer
|
| 143 |
+
self.tokenizer = AutoTokenizer.from_pretrained(args.sft_model_name)
|
| 144 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 145 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 146 |
+
|
| 147 |
+
# 加载模型
|
| 148 |
+
print("Loading models...")
|
| 149 |
+
self.actor_critic = GPT2ActorCritic.from_pretrained(args.sft_model_name).to(self.device)
|
| 150 |
+
self.reward_model = GPT2RewardModel.from_pretrained(args.reward_model_name).to(self.device)
|
| 151 |
+
self.reward_model.eval()
|
| 152 |
+
|
| 153 |
+
# 参考模型(冻结)
|
| 154 |
+
self.ref_model = copy.deepcopy(self.actor_critic).to(self.device)
|
| 155 |
+
self.ref_model.eval()
|
| 156 |
+
|
| 157 |
+
# 优化器
|
| 158 |
+
self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr=args.lr)
|
| 159 |
+
|
| 160 |
+
# 训练状态
|
| 161 |
+
self.training_step = 0
|
| 162 |
+
|
| 163 |
+
def prepare_dataset(self):
|
| 164 |
+
"""准备训练数据集"""
|
| 165 |
+
print("Loading dataset...")
|
| 166 |
+
dataset = load_dataset(
|
| 167 |
+
path=self.args.dataset_path,
|
| 168 |
+
cache_dir=self.args.dataset_cache_dir,
|
| 169 |
+
split="train"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def filter_and_truncate(example):
|
| 173 |
+
# 只保留足够长的句子
|
| 174 |
+
tokens = self.tokenizer(example["sentence"])["input_ids"]
|
| 175 |
+
if len(tokens) <= 8:
|
| 176 |
+
return False
|
| 177 |
+
|
| 178 |
+
# 随机截取前2-6个token作为query
|
| 179 |
+
example["query_ids"] = tokens[:random.randint(2, 6)]
|
| 180 |
+
return True
|
| 181 |
+
|
| 182 |
+
dataset = dataset.filter(filter_and_truncate)
|
| 183 |
+
dataset = dataset.select(range(min(len(dataset), 5000))) # CPU上用小数据集
|
| 184 |
+
|
| 185 |
+
return dataset
|
| 186 |
+
|
| 187 |
+
def collect_rollouts(self, batch):
|
| 188 |
+
"""收集一轮交互数据"""
|
| 189 |
+
query_ids_list = []
|
| 190 |
+
response_ids_list = []
|
| 191 |
+
rewards_list = []
|
| 192 |
+
|
| 193 |
+
for i in range(len(batch["query_ids"])):
|
| 194 |
+
query_ids = torch.tensor(batch["query_ids"][i]).to(self.device)
|
| 195 |
+
query_ids_list.append(query_ids)
|
| 196 |
+
|
| 197 |
+
# 生成response
|
| 198 |
+
with torch.no_grad():
|
| 199 |
+
response_len = random.randint(
|
| 200 |
+
self.args.min_response_len,
|
| 201 |
+
self.args.max_response_len
|
| 202 |
+
)
|
| 203 |
+
full_ids = self.actor_critic.generate(
|
| 204 |
+
input_ids=query_ids.unsqueeze(0),
|
| 205 |
+
max_new_tokens=response_len,
|
| 206 |
+
do_sample=True,
|
| 207 |
+
top_p=self.args.top_p,
|
| 208 |
+
temperature=self.args.temperature,
|
| 209 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 210 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
| 211 |
+
)[0]
|
| 212 |
+
|
| 213 |
+
response_ids = full_ids[len(query_ids):]
|
| 214 |
+
response_ids_list.append(response_ids)
|
| 215 |
+
|
| 216 |
+
# 计算奖励(只取最后一个token的奖励)
|
| 217 |
+
reward = self.reward_model(
|
| 218 |
+
full_ids.unsqueeze(0),
|
| 219 |
+
attention_mask=torch.ones_like(full_ids).unsqueeze(0)
|
| 220 |
+
)[0, -1]
|
| 221 |
+
# 缩放到[-1, 1]
|
| 222 |
+
rewards_list.append(2 * (reward - 0.5))
|
| 223 |
+
|
| 224 |
+
return query_ids_list, response_ids_list, rewards_list
|
| 225 |
+
|
| 226 |
+
def compute_advantages_and_returns(self, log_probs, values, rewards, masks):
|
| 227 |
+
"""计算GAE advantages和returns"""
|
| 228 |
+
seq_len = rewards.shape[1]
|
| 229 |
+
advantages = torch.zeros_like(rewards)
|
| 230 |
+
returns = torch.zeros_like(rewards)
|
| 231 |
+
|
| 232 |
+
gae = 0
|
| 233 |
+
for t in reversed(range(seq_len)):
|
| 234 |
+
if t == seq_len - 1:
|
| 235 |
+
next_value = 0
|
| 236 |
+
else:
|
| 237 |
+
next_value = values[:, t + 1]
|
| 238 |
+
|
| 239 |
+
delta = rewards[:, t] + self.args.gamma * next_value - values[:, t]
|
| 240 |
+
gae = delta + self.args.gamma * self.args.lam * gae
|
| 241 |
+
advantages[:, t] = gae
|
| 242 |
+
returns[:, t] = advantages[:, t] + values[:, t]
|
| 243 |
+
|
| 244 |
+
# 只对有效位置进行whiten
|
| 245 |
+
advantages = self.masked_whiten(advantages, masks)
|
| 246 |
+
return advantages, returns
|
| 247 |
+
|
| 248 |
+
def masked_whiten(self, values, mask):
|
| 249 |
+
"""带mask的whitening"""
|
| 250 |
+
mask = mask.float()
|
| 251 |
+
mean = (values * mask).sum() / mask.sum()
|
| 252 |
+
var = (((values - mean) * mask) ** 2).sum() / mask.sum()
|
| 253 |
+
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
|
| 254 |
+
return whitened * mask
|
| 255 |
+
|
| 256 |
+
def ppo_step(self, batch_data):
|
| 257 |
+
"""单步PPO更新"""
|
| 258 |
+
(query_ids_list, response_ids_list, old_log_probs,
|
| 259 |
+
advantages, returns, masks) = batch_data
|
| 260 |
+
|
| 261 |
+
# 拼接完整的query+response
|
| 262 |
+
full_ids_list = []
|
| 263 |
+
for q, r in zip(query_ids_list, response_ids_list):
|
| 264 |
+
full_ids_list.append(torch.cat([q, r]))
|
| 265 |
+
|
| 266 |
+
# padding
|
| 267 |
+
padded = self.tokenizer.pad(
|
| 268 |
+
{"input_ids": full_ids_list},
|
| 269 |
+
padding=True,
|
| 270 |
+
return_tensors="pt"
|
| 271 |
+
)
|
| 272 |
+
input_ids = padded["input_ids"].to(self.device)
|
| 273 |
+
attention_mask = padded["attention_mask"].to(self.device)
|
| 274 |
+
|
| 275 |
+
# 前向传播
|
| 276 |
+
logits, values = self.actor_critic(input_ids, attention_mask)
|
| 277 |
+
|
| 278 |
+
# 计算新的log_probs
|
| 279 |
+
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
|
| 280 |
+
log_probs = torch.gather(
|
| 281 |
+
log_probs, 2,
|
| 282 |
+
input_ids[:, 1:].unsqueeze(-1)
|
| 283 |
+
).squeeze(-1)
|
| 284 |
+
|
| 285 |
+
# 只保留response部分的log_probs
|
| 286 |
+
response_start = [len(q) for q in query_ids_list]
|
| 287 |
+
new_log_probs = []
|
| 288 |
+
for i, start in enumerate(response_start):
|
| 289 |
+
new_log_probs.append(log_probs[i, start - 1:start - 1 + len(response_ids_list[i])])
|
| 290 |
+
new_log_probs = torch.cat(new_log_probs)
|
| 291 |
+
|
| 292 |
+
# 计算ratio和PPO损失
|
| 293 |
+
old_log_probs = old_log_probs.detach()
|
| 294 |
+
ratio = torch.exp(new_log_probs - old_log_probs)
|
| 295 |
+
|
| 296 |
+
# 裁剪的policy loss
|
| 297 |
+
surr1 = ratio * advantages
|
| 298 |
+
surr2 = torch.clamp(ratio, 1 - self.args.clip_epsilon,
|
| 299 |
+
1 + self.args.clip_epsilon) * advantages
|
| 300 |
+
policy_loss = -torch.min(surr1, surr2).mean()
|
| 301 |
+
|
| 302 |
+
# value loss
|
| 303 |
+
value_pred = []
|
| 304 |
+
for i, start in enumerate(response_start):
|
| 305 |
+
value_pred.append(values[i, start - 1:start - 1 + len(response_ids_list[i])])
|
| 306 |
+
value_pred = torch.cat(value_pred)
|
| 307 |
+
value_loss = F.mse_loss(value_pred, returns)
|
| 308 |
+
|
| 309 |
+
# 总loss
|
| 310 |
+
loss = policy_loss + 0.5 * value_loss
|
| 311 |
+
|
| 312 |
+
return loss, policy_loss, value_loss
|
| 313 |
+
|
| 314 |
+
def train_epoch(self, dataset):
|
| 315 |
+
"""训练一个epoch"""
|
| 316 |
+
total_policy_loss = 0
|
| 317 |
+
total_value_loss = 0
|
| 318 |
+
num_batches = 0
|
| 319 |
+
|
| 320 |
+
for batch_idx in range(0, len(dataset), self.args.batch_size):
|
| 321 |
+
# 1. 收集数据
|
| 322 |
+
batch = dataset[batch_idx:batch_idx + self.args.batch_size]
|
| 323 |
+
query_ids_list, response_ids_list, rewards_list = self.collect_rollouts(batch)
|
| 324 |
+
|
| 325 |
+
# 2. 计算旧的log_probs和values
|
| 326 |
+
old_log_probs_list = []
|
| 327 |
+
values_list = []
|
| 328 |
+
masks_list = []
|
| 329 |
+
|
| 330 |
+
with torch.no_grad():
|
| 331 |
+
for q_ids, r_ids in zip(query_ids_list, response_ids_list):
|
| 332 |
+
full_ids = torch.cat([q_ids, r_ids]).unsqueeze(0).to(self.device)
|
| 333 |
+
attn_mask = torch.ones_like(full_ids)
|
| 334 |
+
|
| 335 |
+
logits, values = self.actor_critic(full_ids, attn_mask)
|
| 336 |
+
|
| 337 |
+
# 计算response部分的log_probs
|
| 338 |
+
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
|
| 339 |
+
log_probs = torch.gather(
|
| 340 |
+
log_probs, 2,
|
| 341 |
+
full_ids[:, 1:].unsqueeze(-1)
|
| 342 |
+
).squeeze(-1)
|
| 343 |
+
|
| 344 |
+
start = len(q_ids) - 1
|
| 345 |
+
end = start + len(r_ids)
|
| 346 |
+
old_log_probs_list.append(log_probs[0, start:end])
|
| 347 |
+
values_list.append(values[0, start:end])
|
| 348 |
+
|
| 349 |
+
# 创建mask
|
| 350 |
+
mask = torch.zeros(len(r_ids))
|
| 351 |
+
mask[-1] = 1 # 最后一个token有真实奖励
|
| 352 |
+
masks_list.append(mask)
|
| 353 |
+
|
| 354 |
+
# 转换为tensor
|
| 355 |
+
old_log_probs = torch.cat(old_log_probs_list).to(self.device)
|
| 356 |
+
values = torch.cat(values_list).to(self.device)
|
| 357 |
+
masks = torch.cat(masks_list).to(self.device)
|
| 358 |
+
rewards = torch.zeros_like(values).to(self.device)
|
| 359 |
+
|
| 360 |
+
# 设置奖励(只在最后一个token加上环境奖励)
|
| 361 |
+
for i, (r, mask) in enumerate(zip(rewards_list, masks_list)):
|
| 362 |
+
if mask[-1] > 0:
|
| 363 |
+
# KL惩罚
|
| 364 |
+
kl = old_log_probs[i] - old_log_probs[i] # 这里简化了,实际要用ref_model
|
| 365 |
+
kl_penalty = -self.args.kl_beta * kl
|
| 366 |
+
rewards[i] = kl_penalty + r
|
| 367 |
+
|
| 368 |
+
# 3. 计算advantages和returns
|
| 369 |
+
advantages, returns = self.compute_advantages_and_returns(
|
| 370 |
+
old_log_probs.unsqueeze(0),
|
| 371 |
+
values.unsqueeze(0),
|
| 372 |
+
rewards.unsqueeze(0),
|
| 373 |
+
masks.unsqueeze(0)
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# 4. PPO多次更新
|
| 377 |
+
batch_data = (query_ids_list, response_ids_list, old_log_probs,
|
| 378 |
+
advantages.squeeze(0), returns.squeeze(0), masks)
|
| 379 |
+
|
| 380 |
+
for _ in range(self.args.ppo_epochs):
|
| 381 |
+
loss, policy_loss, value_loss = self.ppo_step(batch_data)
|
| 382 |
+
|
| 383 |
+
self.optimizer.zero_grad()
|
| 384 |
+
loss.backward()
|
| 385 |
+
torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 1.0)
|
| 386 |
+
self.optimizer.step()
|
| 387 |
+
|
| 388 |
+
total_policy_loss += policy_loss.item()
|
| 389 |
+
total_value_loss += value_loss.item()
|
| 390 |
+
num_batches += 1
|
| 391 |
+
self.training_step += 1
|
| 392 |
+
|
| 393 |
+
if batch_idx % 100 == 0:
|
| 394 |
+
print(f"Batch {batch_idx}/{len(dataset)}: "
|
| 395 |
+
f"policy_loss={total_policy_loss / num_batches:.4f}, "
|
| 396 |
+
f"value_loss={total_value_loss / num_batches:.4f}")
|
| 397 |
+
|
| 398 |
+
return total_policy_loss / num_batches, total_value_loss / num_batches
|
| 399 |
+
|
| 400 |
+
def train(self):
|
| 401 |
+
"""主训练循环"""
|
| 402 |
+
dataset = self.prepare_dataset()
|
| 403 |
+
print(f"Dataset size: {len(dataset)}")
|
| 404 |
+
|
| 405 |
+
for epoch in range(self.args.max_epochs):
|
| 406 |
+
print(f"\n=== Epoch {epoch + 1}/{self.args.max_epochs} ===")
|
| 407 |
+
policy_loss, value_loss = self.train_epoch(dataset)
|
| 408 |
+
print(f"Epoch {epoch + 1} finished: "
|
| 409 |
+
f"policy_loss={policy_loss:.4f}, value_loss={value_loss:.4f}")
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def main():
|
| 413 |
+
args = get_args()
|
| 414 |
+
print("PPO Training with CPU")
|
| 415 |
+
print(f"Arguments: {args}")
|
| 416 |
+
|
| 417 |
+
# 创建agent并开始训练
|
| 418 |
+
agent = PPOAgent(args)
|
| 419 |
+
agent.train()
|
| 420 |
+
|
| 421 |
+
# 保存模型
|
| 422 |
+
output_dir = Path(args.sft_model_name) / "ppo_trained"
|
| 423 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
| 424 |
+
agent.actor_critic.save_pretrained(output_dir)
|
| 425 |
+
agent.tokenizer.save_pretrained(output_dir)
|
| 426 |
+
print(f"Model saved to {output_dir}")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
main()
|
examples/tutorials/rlhf/gpt2_sst2/step_5_ppo_rlhf2.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import copy
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import platform
|
| 8 |
+
import random
|
| 9 |
+
from typing import Any, Dict, List, Optional, Union, Tuple
|
| 10 |
+
|
| 11 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 12 |
+
from project_settings import project_path, temp_directory
|
| 13 |
+
else:
|
| 14 |
+
project_path = os.path.abspath("../../../")
|
| 15 |
+
project_path = Path(project_path)
|
| 16 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 17 |
+
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
| 24 |
+
from transformers import (AutoTokenizer,
|
| 25 |
+
GPT2PreTrainedModel, GPT2Config, GPT2Model, GPT2LMHeadModel,
|
| 26 |
+
)
|
| 27 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 28 |
+
from transformers import DataCollatorWithPadding
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_args():
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--reward_model_name",
|
| 35 |
+
default=(project_path / "trained_models/gpt2-sst2-reward").as_posix(),
|
| 36 |
+
type=str
|
| 37 |
+
),
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--sft_model_name",
|
| 40 |
+
default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
|
| 41 |
+
type=str
|
| 42 |
+
),
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--dataset_path",
|
| 45 |
+
default="stanfordnlp/sst2",
|
| 46 |
+
type=str
|
| 47 |
+
),
|
| 48 |
+
parser.add_argument("--dataset_name", default=None, type=str),
|
| 49 |
+
parser.add_argument("--dataset_split", default=None, type=str),
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--dataset_cache_dir",
|
| 52 |
+
default=(temp_directory / "hub_datasets").as_posix(),
|
| 53 |
+
type=str
|
| 54 |
+
),
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--model_cache_dir",
|
| 57 |
+
default=(temp_directory / "hub_models").as_posix(),
|
| 58 |
+
type=str
|
| 59 |
+
),
|
| 60 |
+
parser.add_argument("--dataset_streaming", default=None, type=str),
|
| 61 |
+
parser.add_argument("--valid_dataset_size", default=1000, type=int),
|
| 62 |
+
parser.add_argument("--shuffle_buffer_size", default=5000, type=int),
|
| 63 |
+
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--output_model_dir",
|
| 66 |
+
default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
|
| 67 |
+
type=str
|
| 68 |
+
),
|
| 69 |
+
# train
|
| 70 |
+
parser.add_argument("--batch_size", default=32, type=int)
|
| 71 |
+
|
| 72 |
+
# generator
|
| 73 |
+
parser.add_argument(
|
| 74 |
+
"--max_new_tokens",
|
| 75 |
+
default=128, # 8192, 128
|
| 76 |
+
type=int,
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument("--top_p", default=0.85, type=float)
|
| 79 |
+
parser.add_argument("--temperature", default=0.85, type=float)
|
| 80 |
+
|
| 81 |
+
# other
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--num_workers",
|
| 84 |
+
default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
|
| 85 |
+
type=int
|
| 86 |
+
),
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--device",
|
| 89 |
+
default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
| 90 |
+
type=int
|
| 91 |
+
),
|
| 92 |
+
args = parser.parse_args()
|
| 93 |
+
return args
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class RewardHead(nn.Module):
|
| 97 |
+
def __init__(self, hidden_size: int):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.hidden_size = hidden_size
|
| 100 |
+
self.linear = nn.Linear(self.hidden_size, 1)
|
| 101 |
+
self._post_init()
|
| 102 |
+
|
| 103 |
+
def _post_init(self):
|
| 104 |
+
nn.init.normal_(
|
| 105 |
+
self.linear.weight,
|
| 106 |
+
std=(1.0 / np.sqrt(self.hidden_size + 1))
|
| 107 |
+
)
|
| 108 |
+
nn.init.zeros_(self.linear.bias)
|
| 109 |
+
|
| 110 |
+
def forward(self, hidden_states):
|
| 111 |
+
# hidden_states shape: [batch_size, seq_len, hidden_size]
|
| 112 |
+
reward_logits = self.linear(hidden_states)
|
| 113 |
+
# reward_logits shape: [batch_size, seq_len, 1]
|
| 114 |
+
return reward_logits
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class GPT2RewardModel(GPT2PreTrainedModel):
|
| 118 |
+
def __init__(self, config: GPT2Config):
|
| 119 |
+
super().__init__(config)
|
| 120 |
+
self.transformer = GPT2Model(config)
|
| 121 |
+
self.reward_head = RewardHead(config.hidden_size)
|
| 122 |
+
self.post_init()
|
| 123 |
+
|
| 124 |
+
def forward(
|
| 125 |
+
self,
|
| 126 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 127 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 128 |
+
) -> Union[Tuple, torch.Tensor]:
|
| 129 |
+
transformer_outputs = self.transformer(
|
| 130 |
+
input_ids=input_ids,
|
| 131 |
+
attention_mask=attention_mask,
|
| 132 |
+
output_hidden_states=True
|
| 133 |
+
)
|
| 134 |
+
last_hidden_state = transformer_outputs.hidden_states[-1]
|
| 135 |
+
# last_hidden_state shape: [batch_size, seq_len, hidden_size]
|
| 136 |
+
rewards_logits = self.reward_head(last_hidden_state)
|
| 137 |
+
# rewards_logits shape: [batch_size, seq_len, 1]
|
| 138 |
+
rewards_logits = torch.squeeze(rewards_logits, -1)
|
| 139 |
+
# rewards_logits shape: [batch_size, seq_len]
|
| 140 |
+
rewards = torch.sigmoid(rewards_logits)
|
| 141 |
+
# rewards shape: [batch_size, seq_len]
|
| 142 |
+
return rewards
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class ValueHead(nn.Module):
|
| 146 |
+
def __init__(self, hidden_size: int):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.hidden_size = hidden_size
|
| 149 |
+
self.linear = nn.Linear(self.hidden_size, 1)
|
| 150 |
+
self._post_init()
|
| 151 |
+
|
| 152 |
+
def _post_init(self):
|
| 153 |
+
nn.init.normal_(
|
| 154 |
+
self.linear.weight,
|
| 155 |
+
std=(1.0 / np.sqrt(self.hidden_size + 1))
|
| 156 |
+
)
|
| 157 |
+
nn.init.zeros_(self.linear.bias)
|
| 158 |
+
|
| 159 |
+
def forward(self, hidden_states):
|
| 160 |
+
# hidden_states shape: [batch_size, seq_len, hidden_size]
|
| 161 |
+
reward_logits = self.linear(hidden_states)
|
| 162 |
+
# reward_logits shape: [batch_size, seq_len, 1]
|
| 163 |
+
return reward_logits
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class GPT2ActorCriticModel(GPT2PreTrainedModel):
|
| 167 |
+
def __init__(self, config: GPT2Config):
|
| 168 |
+
super().__init__(config)
|
| 169 |
+
self.lm = GPT2LMHeadModel(config)
|
| 170 |
+
self.value_head = ValueHead(config.hidden_size)
|
| 171 |
+
self.post_init()
|
| 172 |
+
|
| 173 |
+
def forward(
|
| 174 |
+
self,
|
| 175 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 176 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 177 |
+
) -> Union[Tuple, torch.Tensor]:
|
| 178 |
+
transformer_outputs = self.lm.forward(
|
| 179 |
+
input_ids,
|
| 180 |
+
attention_mask=attention_mask,
|
| 181 |
+
output_hidden_states=True,
|
| 182 |
+
)
|
| 183 |
+
lm_logits = transformer_outputs.logits
|
| 184 |
+
|
| 185 |
+
# values
|
| 186 |
+
last_hidden_state = transformer_outputs.hidden_states[-1]
|
| 187 |
+
# last_hidden_state shape: [batch_size, seq_len, hidden_size]
|
| 188 |
+
values_logits = self.value_head(last_hidden_state)
|
| 189 |
+
# values_logits shape: [batch_size, seq_len, 1]
|
| 190 |
+
values = torch.squeeze(values_logits, -1)
|
| 191 |
+
# values shape: [batch_size, seq_len]
|
| 192 |
+
values = torch.sigmoid(values)
|
| 193 |
+
# values shape: [batch_size, seq_len]
|
| 194 |
+
return lm_logits, values
|
| 195 |
+
|
| 196 |
+
@classmethod
|
| 197 |
+
def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
|
| 198 |
+
config = GPT2Config.from_pretrained(model_name_or_path)
|
| 199 |
+
model = cls(config)
|
| 200 |
+
pretrained_model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
|
| 201 |
+
model.lm.load_state_dict(pretrained_model.state_dict(), strict=False)
|
| 202 |
+
return model
|
| 203 |
+
|
| 204 |
+
def generate(self, *args, **kwargs):
|
| 205 |
+
return self.lm.generate(*args, **kwargs)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def masked_mean(values, mask):
|
| 209 |
+
# 计算带掩码的平均值
|
| 210 |
+
return (values * mask).sum() / mask.sum()
|
| 211 |
+
|
| 212 |
+
def masked_var(values, mask):
|
| 213 |
+
# 计算带掩码的方差
|
| 214 |
+
mean = masked_mean(values, mask)
|
| 215 |
+
centred_values = values - mean
|
| 216 |
+
return masked_mean(centred_values ** 2, mask)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def masked_whiten(values, mask):
|
| 220 |
+
"""
|
| 221 |
+
对数据进行带掩码的白化处理,
|
| 222 |
+
让有效数据的方差变为1,但均值保持不变
|
| 223 |
+
"""
|
| 224 |
+
mean, var = masked_mean(values, mask), masked_var(values, mask)
|
| 225 |
+
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
|
| 226 |
+
whitened += mean
|
| 227 |
+
return whitened
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def main():
|
| 231 |
+
args = get_args()
|
| 232 |
+
|
| 233 |
+
device = torch.device(args.device)
|
| 234 |
+
|
| 235 |
+
# reward_model
|
| 236 |
+
reward_model = GPT2RewardModel.from_pretrained(args.reward_model_name)
|
| 237 |
+
reward_model = reward_model.to(args.device)
|
| 238 |
+
reward_model.eval()
|
| 239 |
+
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name)
|
| 240 |
+
reward_tokenizer.pad_token = reward_tokenizer.eos_token
|
| 241 |
+
|
| 242 |
+
# actor_critic_model
|
| 243 |
+
actor_critic_model = GPT2ActorCriticModel.from_pretrained(args.sft_model_name)
|
| 244 |
+
actor_critic_model = actor_critic_model.to(args.device)
|
| 245 |
+
actor_critic_tokenizer = AutoTokenizer.from_pretrained(args.sft_model_name)
|
| 246 |
+
actor_critic_tokenizer.pad_token = actor_critic_tokenizer.eos_token
|
| 247 |
+
actor_critic_tokenizer.pad_token_id = actor_critic_tokenizer.eos_token_id
|
| 248 |
+
|
| 249 |
+
# ref_model
|
| 250 |
+
ref_model = copy.deepcopy(actor_critic_model)
|
| 251 |
+
ref_model = ref_model.to(args.device)
|
| 252 |
+
ref_model.eval()
|
| 253 |
+
|
| 254 |
+
dataset_dict = load_dataset(
|
| 255 |
+
path=args.dataset_path,
|
| 256 |
+
name=args.dataset_name,
|
| 257 |
+
split=args.dataset_split,
|
| 258 |
+
cache_dir=args.dataset_cache_dir,
|
| 259 |
+
# num_proc=args.num_workers if not args.dataset_streaming else None,
|
| 260 |
+
streaming=args.dataset_streaming,
|
| 261 |
+
)
|
| 262 |
+
train_dataset = dataset_dict["train"]
|
| 263 |
+
# valid_dataset = dataset_dict["validation"]
|
| 264 |
+
# test_dataset = dataset_dict["test"]
|
| 265 |
+
|
| 266 |
+
def format_func(example):
|
| 267 |
+
sentence: str = example["sentence"]
|
| 268 |
+
score: float = float(example["label"])
|
| 269 |
+
tokenized = actor_critic_tokenizer(sentence)
|
| 270 |
+
input_ids = tokenized["input_ids"]
|
| 271 |
+
attention_mask = tokenized["attention_mask"]
|
| 272 |
+
result = {
|
| 273 |
+
"input_ids": input_ids,
|
| 274 |
+
"attention_mask": attention_mask,
|
| 275 |
+
}
|
| 276 |
+
return result
|
| 277 |
+
|
| 278 |
+
train_dataset = train_dataset.map(
|
| 279 |
+
format_func,
|
| 280 |
+
batched=False,
|
| 281 |
+
remove_columns=train_dataset.column_names,
|
| 282 |
+
)
|
| 283 |
+
train_dataset = train_dataset.filter(
|
| 284 |
+
function=lambda x: len(x["input_ids"]) > 8
|
| 285 |
+
)
|
| 286 |
+
def token_truncate(example):
|
| 287 |
+
target_length = random.randint(2, 6)
|
| 288 |
+
input_ids = example["input_ids"]
|
| 289 |
+
attention_mask = example["attention_mask"]
|
| 290 |
+
input_ids = input_ids[:target_length]
|
| 291 |
+
attention_mask = attention_mask[:target_length]
|
| 292 |
+
text = actor_critic_tokenizer.decode(input_ids)
|
| 293 |
+
result = {
|
| 294 |
+
"input_ids": input_ids,
|
| 295 |
+
"attention_mask": attention_mask,
|
| 296 |
+
# "text": text,
|
| 297 |
+
}
|
| 298 |
+
return result
|
| 299 |
+
|
| 300 |
+
train_dataset = train_dataset.map(
|
| 301 |
+
token_truncate,
|
| 302 |
+
batched=False,
|
| 303 |
+
remove_columns=train_dataset.column_names,
|
| 304 |
+
)
|
| 305 |
+
data_collator = DataCollatorWithPadding(
|
| 306 |
+
tokenizer=actor_critic_tokenizer,
|
| 307 |
+
padding=True,
|
| 308 |
+
)
|
| 309 |
+
train_data_loader = DataLoader(
|
| 310 |
+
dataset=train_dataset,
|
| 311 |
+
batch_size=args.batch_size,
|
| 312 |
+
shuffle=True,
|
| 313 |
+
num_workers=args.num_workers or 0,
|
| 314 |
+
collate_fn=data_collator,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# for train_batch in train_data_loader:
|
| 318 |
+
# print(train_batch)
|
| 319 |
+
|
| 320 |
+
for epoch_id in range(10):
|
| 321 |
+
for batch in train_data_loader:
|
| 322 |
+
input_ids = batch["input_ids"]
|
| 323 |
+
attention_mask = batch["attention_mask"]
|
| 324 |
+
|
| 325 |
+
query_ids_list = list()
|
| 326 |
+
query_and_response_ids_list = list()
|
| 327 |
+
response_ids_list = list()
|
| 328 |
+
reward_list = list()
|
| 329 |
+
for idx in range(args.batch_size):
|
| 330 |
+
input_ids_ = input_ids[idx]
|
| 331 |
+
attention_mask_ = attention_mask[idx]
|
| 332 |
+
input_ids_ = input_ids_.to(device)
|
| 333 |
+
attention_mask_ = attention_mask_.to(device)
|
| 334 |
+
|
| 335 |
+
with torch.no_grad():
|
| 336 |
+
query_and_response_ids = actor_critic_model.generate(
|
| 337 |
+
input_ids=input_ids_.unsqueeze(0),
|
| 338 |
+
attention_mask=attention_mask_.unsqueeze(0),
|
| 339 |
+
max_new_tokens=random.randint(5, 16),
|
| 340 |
+
do_sample=True,
|
| 341 |
+
top_p=0.85,
|
| 342 |
+
temperature=0.85,
|
| 343 |
+
pad_token_id=actor_critic_tokenizer.pad_token_id,
|
| 344 |
+
eos_token_id=actor_critic_tokenizer.eos_token_id,
|
| 345 |
+
repetition_penalty=1.0,
|
| 346 |
+
early_stopping=True,
|
| 347 |
+
).squeeze(0)
|
| 348 |
+
query_ids_list.append(input_ids_)
|
| 349 |
+
query_and_response_ids_list.append(query_and_response_ids)
|
| 350 |
+
response_ids = query_and_response_ids[len(input_ids_):]
|
| 351 |
+
response_ids_list.append(response_ids)
|
| 352 |
+
|
| 353 |
+
reward = reward_model(
|
| 354 |
+
input_ids=query_and_response_ids.unsqueeze(0),
|
| 355 |
+
attention_mask=torch.ones_like(query_and_response_ids, dtype=torch.long).unsqueeze(0),
|
| 356 |
+
).squeeze(0)[-1]
|
| 357 |
+
# 将奖励模型的评分从(0,1)缩放到(-1,1)
|
| 358 |
+
reward = 2 * (reward - 0.5)
|
| 359 |
+
reward_list.append(reward)
|
| 360 |
+
|
| 361 |
+
for query_ids, query_and_response_ids in zip(query_ids_list, query_and_response_ids_list):
|
| 362 |
+
print(actor_critic_tokenizer.decode(query_ids, skip_special_tokens=False))
|
| 363 |
+
print(actor_critic_tokenizer.decode(query_ids, skip_special_tokens=True))
|
| 364 |
+
print(actor_critic_tokenizer.decode(query_and_response_ids, skip_special_tokens=True))
|
| 365 |
+
exit(0)
|
| 366 |
+
#计算奖励
|
| 367 |
+
batch_ = list()
|
| 368 |
+
for query_and_response_ids in query_and_response_ids_list:
|
| 369 |
+
print(actor_critic_tokenizer.decode(query_and_response_ids))
|
| 370 |
+
batch_.append({
|
| 371 |
+
"input_ids": query_and_response_ids,
|
| 372 |
+
"attention_mask": torch.ones_like(query_and_response_ids),
|
| 373 |
+
})
|
| 374 |
+
|
| 375 |
+
batch_ = data_collator(batch_)
|
| 376 |
+
input_ids = batch_["input_ids"]
|
| 377 |
+
attention_mask = batch_["attention_mask"]
|
| 378 |
+
input_ids = input_ids.to(device)
|
| 379 |
+
attention_mask = attention_mask.to(device)
|
| 380 |
+
|
| 381 |
+
logits, values = actor_critic_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 382 |
+
ref_logits, _ = ref_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 383 |
+
log_prob = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
|
| 384 |
+
ref_log_prob = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)
|
| 385 |
+
index = input_ids[:, 1:].unsqueeze(-1)
|
| 386 |
+
log_prob = torch.gather(log_prob, dim=2, index=index).squeeze(-1)
|
| 387 |
+
ref_log_prob = torch.gather(ref_log_prob, dim=2, index=index).squeeze(-1)
|
| 388 |
+
|
| 389 |
+
kl = log_prob - ref_log_prob
|
| 390 |
+
beta = 0.2
|
| 391 |
+
kl_penalty = - beta * kl
|
| 392 |
+
|
| 393 |
+
rewards = kl_penalty
|
| 394 |
+
masks = torch.zeros_like(input_ids[:, 1:])
|
| 395 |
+
for idx in range(args.batch_size):
|
| 396 |
+
start = len(query_ids_list[idx]) - 1
|
| 397 |
+
end = start + len(response_ids_list[idx])
|
| 398 |
+
masks[idx, start:end] = 1
|
| 399 |
+
rewards[idx, end - 1] += reward_list[idx]
|
| 400 |
+
values[idx, :-1] *= masks[idx, :]
|
| 401 |
+
values[idx, -1] = 0
|
| 402 |
+
rewards = rewards * masks
|
| 403 |
+
|
| 404 |
+
# log_prob, rewards, kl_penalty, masks shape: [b, seq_len - 1]
|
| 405 |
+
# values shape: [b, seq_len]
|
| 406 |
+
|
| 407 |
+
# 计算优势
|
| 408 |
+
seq_len = rewards.shape[-1]
|
| 409 |
+
last_gae = 0.0
|
| 410 |
+
gamma, lam = 1.0, 0.95
|
| 411 |
+
advantage_reversed = list()
|
| 412 |
+
for t in reversed(range(seq_len)):
|
| 413 |
+
next_value = values[:, t + 1] if t < seq_len - 1 else 0.0
|
| 414 |
+
delta = rewards[:, t] + gamma * next_value - values[:, t]
|
| 415 |
+
last_gae = delta + gamma * lam * last_gae
|
| 416 |
+
advantage_reversed.append(last_gae)
|
| 417 |
+
advantages = torch.stack(advantage_reversed[::-1], dim=1)
|
| 418 |
+
advantages = masked_whiten(advantages, masks)
|
| 419 |
+
returns = advantages + values[:, :-1]
|
| 420 |
+
|
| 421 |
+
# advantages shape: [b, seq_len-1]
|
| 422 |
+
# returns shape: [b, seq_len-1]
|
| 423 |
+
|
| 424 |
+
exit(0)
|
| 425 |
+
|
| 426 |
+
return
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
main()
|
examples/tutorials/rlhf/gpt2_sst2/step_5_pre_ppo_rlhf.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import platform
|
| 7 |
+
from typing import Any, Dict, List, Optional, Union, Tuple
|
| 8 |
+
|
| 9 |
+
if platform.system() in ("Windows", "Darwin"):
|
| 10 |
+
from project_settings import project_path, temp_directory
|
| 11 |
+
else:
|
| 12 |
+
project_path = os.path.abspath("../../../")
|
| 13 |
+
project_path = Path(project_path)
|
| 14 |
+
temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
|
| 15 |
+
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
| 21 |
+
from transformers import (AutoTokenizer,
|
| 22 |
+
GPT2PreTrainedModel, GPT2Config, GPT2Model, GPT2LMHeadModel,
|
| 23 |
+
)
|
| 24 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_args():
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--reward_model_name",
|
| 31 |
+
default=(project_path / "trained_models/gpt2-sst2-reward").as_posix(),
|
| 32 |
+
type=str
|
| 33 |
+
),
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--policy_model_name",
|
| 36 |
+
default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
|
| 37 |
+
type=str
|
| 38 |
+
),
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--dataset_path",
|
| 41 |
+
default="stanfordnlp/sst2",
|
| 42 |
+
type=str
|
| 43 |
+
),
|
| 44 |
+
parser.add_argument("--dataset_name", default=None, type=str),
|
| 45 |
+
parser.add_argument("--dataset_split", default=None, type=str),
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--dataset_cache_dir",
|
| 48 |
+
default=(temp_directory / "hub_datasets").as_posix(),
|
| 49 |
+
type=str
|
| 50 |
+
),
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--model_cache_dir",
|
| 53 |
+
default=(temp_directory / "hub_models").as_posix(),
|
| 54 |
+
type=str
|
| 55 |
+
),
|
| 56 |
+
parser.add_argument("--dataset_streaming", default=None, type=str),
|
| 57 |
+
parser.add_argument("--valid_dataset_size", default=1000, type=int),
|
| 58 |
+
parser.add_argument("--shuffle_buffer_size", default=5000, type=int),
|
| 59 |
+
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--output_model_dir",
|
| 62 |
+
default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
|
| 63 |
+
type=str
|
| 64 |
+
),
|
| 65 |
+
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--max_new_tokens",
|
| 68 |
+
default=128, # 8192, 128
|
| 69 |
+
type=int, help="最大生成长度(注意:并非模型实际长文本能力)"
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
|
| 72 |
+
parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
|
| 73 |
+
|
| 74 |
+
# other
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--num_workers",
|
| 77 |
+
default=None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2,
|
| 78 |
+
type=int
|
| 79 |
+
),
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--device",
|
| 82 |
+
default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
| 83 |
+
type=int
|
| 84 |
+
),
|
| 85 |
+
args = parser.parse_args()
|
| 86 |
+
return args
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class RewardHead(nn.Module):
|
| 90 |
+
def __init__(self, hidden_size: int):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.hidden_size = hidden_size
|
| 93 |
+
self.linear = nn.Linear(self.hidden_size, 1)
|
| 94 |
+
self._post_init()
|
| 95 |
+
|
| 96 |
+
def _post_init(self):
|
| 97 |
+
nn.init.normal_(
|
| 98 |
+
self.linear.weight,
|
| 99 |
+
std=(1.0 / np.sqrt(self.hidden_size + 1))
|
| 100 |
+
)
|
| 101 |
+
nn.init.zeros_(self.linear.bias)
|
| 102 |
+
|
| 103 |
+
def forward(self, hidden_states):
|
| 104 |
+
# hidden_states shape: [batch_size, seq_len, hidden_size]
|
| 105 |
+
reward_logits = self.linear(hidden_states)
|
| 106 |
+
# reward_logits shape: [batch_size, seq_len, 1]
|
| 107 |
+
return reward_logits
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class GPT2RewardModel(GPT2PreTrainedModel):
|
| 111 |
+
def __init__(self, config: GPT2Config):
|
| 112 |
+
super().__init__(config)
|
| 113 |
+
self.transformer = GPT2Model(config)
|
| 114 |
+
self.reward_head = RewardHead(config.hidden_size)
|
| 115 |
+
self.post_init()
|
| 116 |
+
|
| 117 |
+
def forward(
|
| 118 |
+
self,
|
| 119 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 120 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 121 |
+
) -> Union[Tuple, torch.Tensor]:
|
| 122 |
+
transformer_outputs = self.transformer(
|
| 123 |
+
input_ids=input_ids,
|
| 124 |
+
attention_mask=attention_mask,
|
| 125 |
+
output_hidden_states=True
|
| 126 |
+
)
|
| 127 |
+
last_hidden_state = transformer_outputs.hidden_states[-1]
|
| 128 |
+
# last_hidden_state shape: [batch_size, seq_len, hidden_size]
|
| 129 |
+
rewards_logits = self.reward_head(last_hidden_state)
|
| 130 |
+
# rewards_logits shape: [batch_size, seq_len, 1]
|
| 131 |
+
rewards_logits = torch.squeeze(rewards_logits, -1)
|
| 132 |
+
# rewards_logits shape: [batch_size, seq_len]
|
| 133 |
+
rewards = torch.sigmoid(rewards_logits)
|
| 134 |
+
# rewards shape: [batch_size, seq_len]
|
| 135 |
+
return rewards
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ValueHead(nn.Module):
|
| 139 |
+
def __init__(self, hidden_size: int):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.hidden_size = hidden_size
|
| 142 |
+
self.linear = nn.Linear(self.hidden_size, 1)
|
| 143 |
+
self._post_init()
|
| 144 |
+
|
| 145 |
+
def _post_init(self):
|
| 146 |
+
nn.init.normal_(
|
| 147 |
+
self.linear.weight,
|
| 148 |
+
std=(1.0 / np.sqrt(self.hidden_size + 1))
|
| 149 |
+
)
|
| 150 |
+
nn.init.zeros_(self.linear.bias)
|
| 151 |
+
|
| 152 |
+
def forward(self, hidden_states):
|
| 153 |
+
# hidden_states shape: [batch_size, seq_len, hidden_size]
|
| 154 |
+
reward_logits = self.linear(hidden_states)
|
| 155 |
+
# reward_logits shape: [batch_size, seq_len, 1]
|
| 156 |
+
return reward_logits
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class GPT2ActorCriticModel(GPT2PreTrainedModel):
|
| 160 |
+
def __init__(self, config: GPT2Config):
|
| 161 |
+
super().__init__(config)
|
| 162 |
+
self.lm = GPT2LMHeadModel(config)
|
| 163 |
+
self.value_head = ValueHead(config.hidden_size)
|
| 164 |
+
self.post_init()
|
| 165 |
+
|
| 166 |
+
def forward(
|
| 167 |
+
self,
|
| 168 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 169 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 170 |
+
) -> Union[Tuple, torch.Tensor]:
|
| 171 |
+
transformer_outputs = self.lm.forward(
|
| 172 |
+
input_ids,
|
| 173 |
+
attention_mask=attention_mask,
|
| 174 |
+
output_hidden_states=True,
|
| 175 |
+
)
|
| 176 |
+
lm_logits = transformer_outputs.logits
|
| 177 |
+
|
| 178 |
+
# values
|
| 179 |
+
last_hidden_state = transformer_outputs.hidden_states[-1]
|
| 180 |
+
# last_hidden_state shape: [batch_size, seq_len, hidden_size]
|
| 181 |
+
values_logits = self.value_head(last_hidden_state)
|
| 182 |
+
# values_logits shape: [batch_size, seq_len, 1]
|
| 183 |
+
values = torch.squeeze(values_logits, -1)
|
| 184 |
+
# values shape: [batch_size, seq_len]
|
| 185 |
+
values = torch.sigmoid(values)
|
| 186 |
+
# values shape: [batch_size, seq_len]
|
| 187 |
+
return lm_logits, values
|
| 188 |
+
|
| 189 |
+
@classmethod
|
| 190 |
+
def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
|
| 191 |
+
config = GPT2Config.from_pretrained(model_name_or_path)
|
| 192 |
+
model = cls(config)
|
| 193 |
+
pretrained_model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
|
| 194 |
+
model.lm.load_state_dict(pretrained_model.state_dict(), strict=False)
|
| 195 |
+
return model
|
| 196 |
+
|
| 197 |
+
def generate(self, *args, **kwargs):
|
| 198 |
+
return self.lm.generate(*args, **kwargs)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def main():
|
| 202 |
+
args = get_args()
|
| 203 |
+
|
| 204 |
+
reward_model = GPT2RewardModel.from_pretrained(args.reward_model_name)
|
| 205 |
+
reward_model = reward_model.to(args.device)
|
| 206 |
+
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name)
|
| 207 |
+
reward_tokenizer.pad_token = reward_tokenizer.eos_token
|
| 208 |
+
print(reward_model)
|
| 209 |
+
print(reward_tokenizer)
|
| 210 |
+
|
| 211 |
+
# tokenized = reward_tokenizer(
|
| 212 |
+
# "this is very good movie, I recommend it.",
|
| 213 |
+
# return_tensors="pt"
|
| 214 |
+
# )
|
| 215 |
+
# rewards = reward_model(**tokenized)
|
| 216 |
+
# rewards = rewards[0]
|
| 217 |
+
# rewards = rewards.detach().cpu().numpy()
|
| 218 |
+
# last_token_reward = rewards[-1]
|
| 219 |
+
# # rewards: {rewards}\n
|
| 220 |
+
# msg = f"last_token_reward: {last_token_reward}\n"
|
| 221 |
+
# print(msg)
|
| 222 |
+
# exit(0)
|
| 223 |
+
|
| 224 |
+
# actor_critic_model
|
| 225 |
+
actor_critic_model = GPT2ActorCriticModel.from_pretrained(args.policy_model_name)
|
| 226 |
+
actor_critic_model = actor_critic_model.to(args.device)
|
| 227 |
+
actor_critic_tokenizer = AutoTokenizer.from_pretrained(args.policy_model_name)
|
| 228 |
+
actor_critic_tokenizer.pad_token = actor_critic_tokenizer.eos_token
|
| 229 |
+
print(actor_critic_model)
|
| 230 |
+
print(actor_critic_tokenizer)
|
| 231 |
+
|
| 232 |
+
tokenized = actor_critic_tokenizer(
|
| 233 |
+
"this is ",
|
| 234 |
+
return_tensors="pt"
|
| 235 |
+
)
|
| 236 |
+
streamer = TextStreamer(actor_critic_tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 237 |
+
generated_ids = actor_critic_model.generate(
|
| 238 |
+
inputs=tokenized["input_ids"], attention_mask=tokenized["attention_mask"],
|
| 239 |
+
max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
|
| 240 |
+
pad_token_id=actor_critic_tokenizer.pad_token_id, eos_token_id=actor_critic_tokenizer.eos_token_id,
|
| 241 |
+
top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0,
|
| 242 |
+
)
|
| 243 |
+
response = actor_critic_tokenizer.decode(generated_ids[0][len(tokenized["input_ids"][0]):], skip_special_tokens=True)
|
| 244 |
+
print(response)
|
| 245 |
+
|
| 246 |
+
tokenized = actor_critic_tokenizer(
|
| 247 |
+
"this is very good movie, I recommend it.",
|
| 248 |
+
return_tensors="pt"
|
| 249 |
+
)
|
| 250 |
+
lm_logits, values = actor_critic_model(**tokenized)
|
| 251 |
+
print(values)
|
| 252 |
+
|
| 253 |
+
return
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == "__main__":
|
| 257 |
+
main()
|
tabs/chat_template_tab.py
CHANGED
|
@@ -13,8 +13,10 @@ def run_chat_template(conversation: str, model_name: str, add_generation_prompt:
|
|
| 13 |
|
| 14 |
result = tokenizer.apply_chat_template(
|
| 15 |
conversation,
|
|
|
|
| 16 |
tokenize=False,
|
| 17 |
add_generation_prompt=add_generation_prompt,
|
|
|
|
| 18 |
)
|
| 19 |
return result
|
| 20 |
|
|
|
|
| 13 |
|
| 14 |
result = tokenizer.apply_chat_template(
|
| 15 |
conversation,
|
| 16 |
+
# tools=None,
|
| 17 |
tokenize=False,
|
| 18 |
add_generation_prompt=add_generation_prompt,
|
| 19 |
+
# enable_thinking=True,
|
| 20 |
)
|
| 21 |
return result
|
| 22 |
|