miyuki2026 commited on
Commit
3bd251d
·
1 Parent(s): a82ed9a
examples/playground/chat.py CHANGED
@@ -17,7 +17,7 @@ def get_args():
17
  parser.add_argument(
18
  "--pretrained_model_name_or_path",
19
  # default="jingyaogong/MiniMind2",
20
- default=(project_path / "pretrained_models/MiniMind2"),
21
  type=str
22
  )
23
 
 
17
  parser.add_argument(
18
  "--pretrained_model_name_or_path",
19
  # default="jingyaogong/MiniMind2",
20
+ default=(project_path / "pretrained_models/jingyaogong/MiniMind2"),
21
  type=str
22
  )
23
 
examples/playground/generation.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/jingyaogong/minimind/blob/master/eval_llm.py
5
+ """
6
+ import argparse
7
+ import time
8
+
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
11
+
12
+ from project_settings import project_path
13
+
14
+
15
def get_args():
    """Parse the command-line options of the generation playground.

    Returns:
        argparse.Namespace: carries ``pretrained_model_name_or_path`` (str),
        ``max_new_tokens`` (int), ``top_p`` (float) and ``temperature`` (float).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        # argparse only applies ``type`` to string defaults, so the path is
        # rendered to text here; the value is then a ``str`` whether or not
        # the flag is given (same convention as the tutorial scripts).
        # Alternative checkpoint: "trained_models/gpt2-sst2-generation"
        default=(project_path / "trained_models/gpt2-sst2-generation-20260213-2048").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--max_new_tokens",
        default=1024,  # other values that were tried: 8192, 128
        type=int,
        help="最大生成长度(注意:并非模型实际长文本能力)",
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")

    args = parser.parse_args()
    return args
33
+
34
+
35
def main():
    """Load a causal LM checkpoint and stream one sampled continuation of a fixed prompt.

    The decoded sequence (prompt included, special tokens kept) and the raw
    token ids are printed after generation finishes.
    """
    args = get_args()

    # MPS is deliberately not used: everything except CUDA runs on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    model = model.eval().to(device)

    # Other prompts that were tried: "this", "this is ", "who needs mind-bending",
    # "thanks to scott 's charismatic".
    prompt = "eldom has a movie"

    # The encoded prompt has to live on the same device as the model, otherwise
    # generate() fails with a device mismatch as soon as CUDA is selected.
    tokenized = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Some tokenizers define no pad token; fall back to EOS explicitly.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # `early_stopping` only applies to beam search, so it is not passed for sampling.
    generated_ids = model.generate(
        inputs=tokenized["input_ids"],
        attention_mask=tokenized["attention_mask"],
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        streamer=streamer,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=3.0,
    )

    response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print(response)
    print(generated_ids)

    return
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
examples/tutorials/dpo/ultrachat/step_1_prepare_data.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 或使用命令行
5
+ pip install modelscope
6
+ modelscope download \
7
+ --model 'qgyd2021/Qwen3-8B-sft-deepspeed' \
8
+ --local_dir '/root/autodl-tmp/trained_models/Qwen3-8B-sft-deepspeed'
9
+
10
+ """
11
+ import argparse
12
+ import os
13
+ from pathlib import Path
14
+ import platform
15
+
16
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
17
+
18
+ if platform.system() in ("Windows", "Darwin"):
19
+ from project_settings import project_path, temp_directory
20
+ else:
21
+ project_path = os.path.abspath("../../../")
22
+ project_path = Path(project_path)
23
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
24
+
25
+ from modelscope import snapshot_download
26
+ # from huggingface_hub import snapshot_download
27
+
28
+
29
def get_args():
    """Build the CLI for the download step and return the parsed namespace.

    Options:
        --repo_id: identifier of the repository to download.
        --local_dir: directory the snapshot is written to.
    """
    default_local_dir = temp_directory / "../trained_models/openai-community/gpt2"

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", type=str, default="openai-community/gpt2")
    parser.add_argument("--local_dir", type=str, default=default_local_dir.as_posix())
    return parser.parse_args()
39
+
40
+
41
def main():
    """Download the repository selected on the command line into ``--local_dir``.

    Only the ModelScope client is imported by this script, so exactly one
    ModelScope-style call is made. The huggingface_hub equivalent would be
    ``snapshot_download(repo_type="model", repo_id=..., local_dir=...)`` and
    requires switching the import at the top of the file.
    """
    args = get_args()

    snapshot_download(
        model_id=args.repo_id,
        local_dir=args.local_dir,
    )
    return
56
+
57
+
58
+ if __name__ == "__main__":
59
+ main()
examples/tutorials/dpo/ultrachat/step_2_train_sft_model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import platform
7
+
8
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
9
+
10
+ if platform.system() in ("Windows", "Darwin"):
11
+ from project_settings import project_path, temp_directory
12
+ else:
13
+ project_path = os.path.abspath("../../../")
14
+ project_path = Path(project_path)
15
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
16
+
17
+ from datasets import load_dataset
18
+ import torch
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
20
+ from trl import SFTTrainer, SFTConfig
21
+
22
+
23
def get_args():
    """Parse the command-line options of the SFT training script.

    Returns:
        argparse.Namespace: model/dataset locations, cache directories,
        streaming and validation settings, ``max_seq_length``,
        ``output_model_dir`` and ``num_workers``.
    """
    # No worker default on Windows/macOS. Elsewhere use half of the CPUs;
    # os.cpu_count() may return None, in which case one worker is used.
    if platform.system() in ("Windows", "Darwin"):
        default_num_workers = None
    else:
        default_num_workers = (os.cpu_count() or 2) // 2

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default=(project_path / "pretrained_models/Qwen/Qwen2.5-0.5B").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--dataset_path",
        default="HuggingFaceH4/ultrachat_200k",
        type=str,
    )
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )
    parser.add_argument("--dataset_streaming", default="false", type=str)
    parser.add_argument("--valid_dataset_size", default=1000, type=int)
    parser.add_argument("--shuffle_buffer_size", default=5000, type=int)

    parser.add_argument("--max_seq_length", default=2048, type=int)

    parser.add_argument(
        "--output_model_dir",
        default=(project_path / "trained_models/qwen2_5-0_5B-ultrachat-sft").as_posix(),
        type=str,
    )

    parser.add_argument(
        "--num_workers",
        default=default_num_workers,
        type=int,
    )
    args = parser.parse_args()
    return args
66
+
67
+
68
def main():
    """Fine-tune a causal LM on a chat dataset with TRL's ``SFTTrainer``.

    Loads the model and tokenizer from ``--model_name``, loads the training
    split of ``--dataset_path``, trains for one epoch and saves the final
    model to ``--output_model_dir``.
    """
    args = get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    # The parser defines no --device option; the device detected above is used.
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    train_dataset = load_dataset(
        path=args.dataset_path,
        name=args.dataset_name,
        # --dataset_split overrides the split; "train_sft" stays the default.
        split=args.dataset_split or "train_sft",
        cache_dir=args.dataset_cache_dir,
        streaming=args.dataset_streaming in ("true",),
    )

    sft_config = SFTConfig(
        output_dir=args.output_model_dir,
        num_train_epochs=1,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=100,
        learning_rate=2e-5,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=torch.cuda.is_available(),
        tf32=torch.cuda.is_available(),
        gradient_checkpointing=True,
        optim="adamw_torch",
        remove_unused_columns=False,
        report_to="none",
        dataloader_num_workers=args.num_workers or 0,
        ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None,

        # SFT specific parameters
        max_length=args.max_seq_length,
        dataset_text_field=None,
        dataset_kwargs={
            "add_special_tokens": True,
        },
    )

    # Build the trainer.
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )

    trainer.train()
    trainer.save_model()
    return
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
132
+
examples/tutorials/mix_lora_unsloth/step_2_train_model.py CHANGED
@@ -146,6 +146,8 @@ def main():
146
  # max_steps = 30,
147
  learning_rate=2e-5, # Reduce to 2e-5 for long training runs
148
  logging_steps=1,
 
 
149
  optim="adamw_8bit",
150
  weight_decay=0.01,
151
  lr_scheduler_type="linear",
 
146
  # max_steps = 30,
147
  learning_rate=2e-5, # Reduce to 2e-5 for long training runs
148
  logging_steps=1,
149
+ save_steps=100, # 每100步保存一次检查点
150
+ save_total_limit=2, # 最多只保留2个检查点,旧的自动清理
151
  optim="adamw_8bit",
152
  weight_decay=0.01,
153
  lr_scheduler_type="linear",
examples/tutorials/rl/cart_pole/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gymnasium
2
+ matplotlib
3
+ pygame
examples/tutorials/rl/cart_pole/step_2_actor_critic.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 相比于 REINFORCE 方法,演员-评论家(Actor-Critic)方法不需要等到一局游戏结束就可以触发优化迭代。
5
+
6
+
7
+
8
+ A2C 计算价值优势,主要是对单次的优势进行了移动平均。
9
+
10
+ 由于函数的优化步长受到优势的直接影响,当优化步长过大时很容易直接跨过最优点,导致优化失败。
11
+ 虽然A2C 已经对历史优势进行了移动平均,但问题仍然存在。
12
+ 尤其是当训练的早期价值函数还没有获得较好的训练,这种问题尤其容易出现。
13
+ 因此需要对优化步长进行截断限制。
14
+
15
+ """
16
+ import gymnasium as gym
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ from torch.distributions import Categorical
23
+ import torch.nn.functional as F
24
+
25
+
26
+ # ============== 1. 基础Actor-Critic ==============
27
class ActorCritic(nn.Module):
    """Actor-critic network whose policy and value heads share one feature trunk."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """Create the shared trunk and the two linear heads.

        Args:
            state_dim: size of the observation vector.
            action_dim: number of discrete actions.
            hidden_dim: width of both hidden layers of the trunk.
        """
        super().__init__()

        # Shared feature extractor: two fully connected layers with ReLU.
        trunk_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.feature_layer = nn.Sequential(*trunk_layers)

        # Actor head: one score per action.
        self.actor = nn.Linear(hidden_dim, action_dim)

        # Critic head: scalar state value.
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` for ``state``.

        ``action_probs`` is the softmax of the actor head over the last
        dimension; ``state_value`` is the raw output of the critic head.
        """
        hidden = self.feature_layer(state)
        policy = torch.softmax(self.actor(hidden), dim=-1)
        value = self.critic(hidden)
        return policy, value
57
+
58
+
59
class ActorCriticAgent:
    """One-step actor-critic agent that updates after every environment step.

    The TD error ``r + gamma * V(s') - V(s)`` serves both as the critic's
    regression error and as the advantage weighting the actor's log-probability.

    Args:
        env: Gymnasium-style environment; ``observation_space.shape[0]`` and
            ``action_space.n`` are read to size the networks.
        actor_lr: learning rate of the actor optimizer. In shared-network mode
            (the mode selected in ``__init__``) it is the rate of the single
            joint optimizer.
        critic_lr: learning rate of the critic optimizer (separate mode only).
        gamma: discount factor.
        hidden_dim: width of the hidden layers.
        render: when true, ``env.render()`` is called before every step.
    """

    def __init__(self,
                 env,
                 actor_lr=1e-3,
                 critic_lr=1e-3,
                 gamma=0.99,
                 hidden_dim=256,
                 render=False):

        self.env = env
        self.gamma = gamma
        self.render = render

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        # Switch between one shared network and two separate networks.
        self.use_shared_network = True

        if self.use_shared_network:
            # Shared trunk with two heads, trained by one optimizer.
            self.actor_critic = ActorCritic(self.state_dim, self.action_dim, hidden_dim)
            self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=actor_lr)
        else:
            # Independent actor and critic networks, each with its own optimizer.
            self.actor = nn.Sequential(
                nn.Linear(self.state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, self.action_dim),
                nn.Softmax(dim=-1)
            )
            self.critic = nn.Sequential(
                nn.Linear(self.state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )
            self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
            self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Per-episode training statistics.
        self.training_stats = {
            'episode_rewards': [],
            'critic_loss': [],
            'actor_loss': [],
            'advantages': []
        }

    def select_action(self, state):
        """Sample an action for ``state``.

        Returns:
            tuple: ``(action, log_prob, state_value)`` where ``action`` is a
            Python int and the other two are tensors still attached to the
            autograd graph (they are consumed by ``update``).
        """
        state = torch.FloatTensor(state).unsqueeze(0)

        if self.use_shared_network:
            action_probs, state_value = self.actor_critic(state)
        else:
            action_probs = self.actor(state)
            state_value = self.critic(state)

        m = Categorical(action_probs)
        action = m.sample()
        log_prob = m.log_prob(action)

        return action.item(), log_prob, state_value

    def update(self, log_prob, state_value, reward, next_state_value, done):
        """Run one actor/critic update for a single transition.

        Args:
            log_prob: log-probability of the action taken (with graph).
            state_value: critic estimate for the current state (with graph).
            reward: reward received for the transition.
            next_state_value: critic estimate for the next state (no graph needed).
            done: when true, the next-state value is not bootstrapped.

        Returns:
            tuple: ``(actor_loss, critic_loss, td_error)`` as Python floats.
        """
        if self.use_shared_network:
            return self._update_shared(log_prob, state_value, reward, next_state_value, done)
        else:
            return self._update_separate(log_prob, state_value, reward, next_state_value, done)

    def _td_error(self, state_value, reward, next_state_value, done):
        """Return ``td_target - state_value`` with the target detached from the graph."""
        td_target = reward + (1 - done) * self.gamma * next_state_value
        td_target = td_target.detach()
        return td_target - state_value

    def _update_shared(self, log_prob, state_value, reward, next_state_value, done):
        """Update the shared network with the combined actor + critic loss."""
        td_error = self._td_error(state_value, reward, next_state_value, done)

        critic_loss = td_error.pow(2).mean()

        # A positive TD error (outcome better than predicted) raises the
        # probability of the action taken; a negative one lowers it.
        actor_loss = -(log_prob * td_error.detach()).mean()

        total_loss = actor_loss + 0.5 * critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 0.5)
        self.optimizer.step()

        return actor_loss.item(), critic_loss.item(), td_error.mean().item()

    def _update_separate(self, log_prob, state_value, reward, next_state_value, done):
        """Update the critic first, then the actor, each with its own optimizer."""
        td_error = self._td_error(state_value, reward, next_state_value, done)

        critic_loss = td_error.pow(2).mean()
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()

        actor_loss = -(log_prob * td_error.detach()).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optimizer.step()

        return actor_loss.item(), critic_loss.item(), td_error.mean().item()

    def train(self, num_episodes=1000, max_steps=5000):
        """Train for ``num_episodes`` episodes of at most ``max_steps`` steps.

        Progress is printed every 20 episodes.

        Returns:
            list: total reward of every episode, in order.
        """
        episode_rewards = []

        for episode in range(num_episodes):
            state, _ = self.env.reset()
            episode_reward = 0
            episode_losses = {'actor': [], 'critic': [], 'advantage': []}

            for _step in range(max_steps):
                if self.render:
                    self.env.render()

                action, log_prob, state_value = self.select_action(state)

                next_state, reward, terminated, truncated, _info = self.env.step(action)
                done = terminated or truncated

                # Value of the next state, without building a graph.
                with torch.no_grad():
                    next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
                    if self.use_shared_network:
                        _, next_state_value = self.actor_critic(next_state_tensor)
                    else:
                        next_state_value = self.critic(next_state_tensor)

                # Only a true terminal state has zero future value. A time-limit
                # truncation ends the episode but must still bootstrap from
                # V(s'), so `terminated` (not `done`) gates the TD target.
                actor_loss, critic_loss, advantage = self.update(
                    log_prob, state_value, reward, next_state_value, terminated
                )

                episode_reward += reward
                episode_losses['actor'].append(actor_loss)
                episode_losses['critic'].append(critic_loss)
                episode_losses['advantage'].append(advantage)

                state = next_state

                if done:
                    break

            # Record per-episode statistics.
            episode_rewards.append(episode_reward)
            self.training_stats['episode_rewards'].append(episode_reward)
            self.training_stats['actor_loss'].append(np.mean(episode_losses['actor']))
            self.training_stats['critic_loss'].append(np.mean(episode_losses['critic']))
            self.training_stats['advantages'].append(np.mean(episode_losses['advantage']))

            # Progress report.
            if (episode + 1) % 20 == 0:
                avg_reward = np.mean(episode_rewards[-20:])
                avg_actor_loss = np.mean(self.training_stats['actor_loss'][-20:])
                avg_critic_loss = np.mean(self.training_stats['critic_loss'][-20:])
                print(f"回合 {episode + 1:4d} | "
                      f"奖励: {episode_reward:5.1f} | "
                      f"平均奖励: {avg_reward:5.1f} | "
                      f"A-Loss: {avg_actor_loss:.4f} | "
                      f"C-Loss: {avg_critic_loss:.4f}")

        return episode_rewards
252
+
253
+
254
+ # ============== 2. A2C (Advantage Actor-Critic) ==============
255
class A2C(nn.Module):
    """A2C network with separate actor and critic MLPs.

    The actor emits unnormalized action scores (logits); the critic emits a
    scalar state value.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """Create both three-layer MLPs.

        Args:
            state_dim: size of the observation vector.
            action_dim: number of discrete actions.
            hidden_dim: width of the two hidden layers of each MLP.
        """
        super().__init__()

        def build_mlp(output_dim):
            # Two hidden ReLU layers followed by a linear output layer.
            return nn.Sequential(
                nn.Linear(state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )

        # The actor is created before the critic so parameters are initialized
        # in the same order.
        self.actor = build_mlp(action_dim)
        self.critic = build_mlp(1)

    def forward(self, state):
        """Return ``(action_logits, state_value)`` for ``state``."""
        return self.actor(state), self.critic(state)
281
+
282
+
283
class A2CAgent:
    """A2C agent using GAE advantages and an entropy bonus.

    Transitions are collected over several episodes and consumed by a single
    gradient step in ``update``.

    Args:
        env: Gymnasium-style environment; ``observation_space.shape[0]`` and
            ``action_space.n`` are read to size the network.
        learning_rate: Adam learning rate.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.
        entropy_coef: weight of the entropy bonus in the loss.
        value_coef: weight of the value loss in the loss.
        max_grad_norm: gradient-norm clipping threshold.
        hidden_dim: width of the hidden layers.
    """

    def __init__(self,
                 env,
                 learning_rate=3e-4,
                 gamma=0.99,
                 gae_lambda=0.95,
                 entropy_coef=0.01,
                 value_coef=0.5,
                 max_grad_norm=0.5,
                 hidden_dim=256):

        self.env = env
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.max_grad_norm = max_grad_norm

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        self.network = A2C(self.state_dim, self.action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)

        # Rollout buffer.
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []

        self.training_stats = {
            'episode_rewards': [],
            'policy_loss': [],
            'value_loss': [],
            'entropy': []
        }

    def select_action(self, state):
        """Sample an action for ``state`` and record the step in the rollout buffer.

        Returns:
            int: the sampled action.
        """
        state = torch.FloatTensor(state).unsqueeze(0)

        # update() recomputes log-probabilities and values from the stored
        # states, so no autograd graph needs to be kept for the rollout.
        with torch.no_grad():
            action_logits, state_value = self.network(state)
            action_probs = F.softmax(action_logits, dim=-1)
            m = Categorical(action_probs)
            action = m.sample()
            log_prob = m.log_prob(action)

        self.states.append(state)
        self.actions.append(action)
        self.log_probs.append(log_prob)
        self.values.append(state_value)

        return action.item()

    def compute_gae(self, rewards, values, dones):
        """Compute generalized advantage estimates (GAE) and returns.

        Args:
            rewards: per-step rewards.
            values: per-step value estimates, same length as ``rewards``.
            dones: per-step flags; a true flag stops both the bootstrap and
                the advantage accumulation at that step.

        Returns:
            tuple: ``(advantages, returns)``, two lists as long as ``rewards``,
            where ``returns[t] = advantages[t] + values[t]``.
        """
        # A zero value is appended so the last step has a successor to index.
        values = list(values) + [0]

        advantages = []
        gae = 0
        for t in reversed(range(len(rewards))):
            not_done = 1 - dones[t]
            delta = rewards[t] + self.gamma * values[t + 1] * not_done - values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages.append(gae)
        # Built back-to-front; one reverse is cheaper than inserting at index 0.
        advantages.reverse()

        returns = [adv + val for adv, val in zip(advantages, values)]

        return advantages, returns

    def update(self):
        """Run one gradient step on the buffered rollout and clear the buffer.

        Returns:
            tuple | None: ``(policy_loss, value_loss, entropy)`` as floats, or
            ``None`` when the buffer holds no transitions.
        """
        if len(self.rewards) == 0:
            return None

        states = torch.cat(self.states)
        actions = torch.cat(self.actions)
        # Plain floats keep the GAE arithmetic free of tensors and autograd.
        values = [v.item() for v in self.values]

        advantages, returns = self.compute_gae(self.rewards, values, self.dones)
        advantages = torch.tensor(advantages, dtype=torch.float32)
        returns = torch.tensor(returns, dtype=torch.float32)

        # Normalize advantages; the std of a single sample is NaN, so a
        # one-element rollout is left unnormalized.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Re-evaluate the stored states with gradients enabled.
        action_logits, state_values = self.network(states)
        action_probs = F.softmax(action_logits, dim=-1)
        m = Categorical(action_probs)
        log_probs = m.log_prob(actions)
        entropy = m.entropy().mean()

        # Drop only the trailing feature axis so a one-step rollout keeps
        # shape (1,) and matches `returns`.
        state_values = state_values.squeeze(-1)
        value_loss = F.mse_loss(state_values, returns)

        policy_loss = -(log_probs * advantages).mean()
        entropy_loss = -self.entropy_coef * entropy

        total_loss = policy_loss + self.value_coef * value_loss + entropy_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
        self.optimizer.step()

        policy_loss_val = policy_loss.item()
        value_loss_val = value_loss.item()
        entropy_val = entropy.item()

        # Clear the rollout buffer.
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []

        return policy_loss_val, value_loss_val, entropy_val

    def train(self, num_episodes=1000, max_steps=500, update_frequency=10):
        """Train for ``num_episodes`` episodes, updating every ``update_frequency`` episodes.

        Progress is printed every 20 episodes.

        Returns:
            list: total (environment) reward of every episode, in order.
        """
        episode_rewards = []

        for episode in range(num_episodes):
            state, _ = self.env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = self.select_action(state)

                next_state, reward, terminated, truncated, _info = self.env.step(action)
                done = terminated or truncated
                # The rollout of this episode also ends when the step budget runs out.
                episode_over = done or step == max_steps - 1

                # A cut-off episode (time limit or step budget) is not a real
                # terminal state: fold gamma * V(s') into the stored reward so
                # the buffer can still mark the episode boundary with a done
                # flag without dropping the future value.
                learn_reward = reward
                if episode_over and not terminated:
                    with torch.no_grad():
                        _, next_value = self.network(torch.FloatTensor(next_state).unsqueeze(0))
                    learn_reward = reward + self.gamma * next_value.item()

                self.rewards.append(learn_reward)
                self.dones.append(episode_over)

                episode_reward += reward
                state = next_state

                if done:
                    break

            episode_rewards.append(episode_reward)
            self.training_stats['episode_rewards'].append(episode_reward)

            # Periodic network update; update() yields None on an empty buffer.
            if (episode + 1) % update_frequency == 0:
                stats = self.update()
                if stats is not None:
                    policy_loss, value_loss, entropy = stats
                    self.training_stats['policy_loss'].append(policy_loss)
                    self.training_stats['value_loss'].append(value_loss)
                    self.training_stats['entropy'].append(entropy)

            # Progress report.
            if (episode + 1) % 20 == 0:
                avg_reward = np.mean(episode_rewards[-20:])
                print(f"A2C - 回合 {episode + 1:4d} | "
                      f"奖励: {episode_reward:5.1f} | "
                      f"平均奖励: {avg_reward:5.1f}")

        return episode_rewards
462
+
463
+
464
+ # ============== 3. 可视化对比 ==============
465
def compare_algorithms():
    """Train the basic actor-critic and A2C agents on CartPole-v1 and compare them.

    Draws a 2x3 figure (reward curves, actor loss, critic loss, A2C entropy,
    convergence box plot over repeated trials, summary table), saves it to
    ``ac_vs_a2c_comparison.png``, shows it and prints a textual summary.

    Returns:
        tuple: ``(ac_agent, a2c_agent)``, the two agents trained for the
        reward-curve comparison.
    """
    print("\n" + "=" * 70)
    print("Actor-Critic算法对比实验")
    print("=" * 70)

    env = gym.make('CartPole-v1')
    try:
        # 1. Basic actor-critic
        print("\n1. 训练基础Actor-Critic...")
        ac_agent = ActorCriticAgent(env, actor_lr=1e-3, critic_lr=1e-3)
        ac_rewards = ac_agent.train(num_episodes=300)

        # 2. A2C
        print("\n2. 训练A2C (带GAE和熵正则化)...")
        a2c_agent = A2CAgent(env, learning_rate=3e-4)
        a2c_rewards = a2c_agent.train(num_episodes=300)
    finally:
        # The environment is only needed while training.
        env.close()

    fig, axes = plt.subplots(2, 3, figsize=(16, 10))

    # --- 1. Reward curves -------------------------------------------------
    ax1 = axes[0, 0]
    ax1.plot(ac_rewards, alpha=0.6, label='Actor-Critic', color='blue')
    ax1.plot(a2c_rewards, alpha=0.6, label='A2C', color='red')

    # Moving averages.
    window = 20
    kernel = np.ones(window) / window
    ac_smooth = np.convolve(ac_rewards, kernel, mode='valid')
    a2c_smooth = np.convolve(a2c_rewards, kernel, mode='valid')

    ax1.plot(range(window - 1, len(ac_smooth) + window - 1), ac_smooth,
             'b-', linewidth=2, label='AC (平滑)')
    ax1.plot(range(window - 1, len(a2c_smooth) + window - 1), a2c_smooth,
             'r-', linewidth=2, label='A2C (平滑)')

    ax1.set_xlabel('回合')
    ax1.set_ylabel('总奖励')
    ax1.set_title('训练奖励对比')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # --- 2. Actor loss ----------------------------------------------------
    ax2 = axes[0, 1]
    # training_stats is a dict, so the series is looked up by key
    # (hasattr() on a dict never finds its keys).
    if ac_agent.training_stats.get('actor_loss'):
        ax2.plot(ac_agent.training_stats['actor_loss'],
                 label='Actor-Critic', color='blue', alpha=0.7)
    ax2.plot(a2c_agent.training_stats['policy_loss'],
             label='A2C', color='red', alpha=0.7)
    ax2.set_xlabel('更新步')
    ax2.set_ylabel('策略损失')
    ax2.set_title('Actor损失对比')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # --- 3. Critic loss ---------------------------------------------------
    ax3 = axes[0, 2]
    if ac_agent.training_stats.get('critic_loss'):
        ax3.plot(ac_agent.training_stats['critic_loss'],
                 label='Actor-Critic', color='blue', alpha=0.7)
    ax3.plot(a2c_agent.training_stats['value_loss'],
             label='A2C', color='red', alpha=0.7)
    ax3.set_xlabel('更新步')
    ax3.set_ylabel('价值损失')
    ax3.set_title('Critic损失对比')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # --- 4. Policy entropy (A2C only) -------------------------------------
    ax4 = axes[1, 0]
    if a2c_agent.training_stats['entropy']:
        ax4.plot(a2c_agent.training_stats['entropy'],
                 color='green', linewidth=2)
    ax4.set_xlabel('更新步')
    ax4.set_ylabel('策略熵')
    ax4.set_title('A2C - 策略熵变化')
    ax4.grid(True, alpha=0.3)

    # --- 5. Convergence-speed box plot ------------------------------------
    ax5 = axes[1, 1]

    def get_convergence_episode(rewards, threshold=450):
        """Return the 1-based index of the first episode whose reward and
        trailing mean (up to 11 episodes) both reach ``threshold``, or
        ``len(rewards)`` if none does."""
        for i, r in enumerate(rewards):
            if r >= threshold and np.mean(rewards[max(0, i - 10):i + 1]) >= threshold:
                return i + 1
        return len(rewards)

    # Repeated trials.
    n_trials = 10
    ac_convergence = []
    a2c_convergence = []

    for _trial in range(n_trials):
        env_trial = gym.make('CartPole-v1')
        try:
            ac_trial = ActorCriticAgent(env_trial, actor_lr=1e-3, critic_lr=1e-3)
            ac_rewards_trial = ac_trial.train(num_episodes=200)
            ac_convergence.append(get_convergence_episode(ac_rewards_trial))

            a2c_trial = A2CAgent(env_trial, learning_rate=3e-4)
            a2c_rewards_trial = a2c_trial.train(num_episodes=200)
            a2c_convergence.append(get_convergence_episode(a2c_rewards_trial))
        finally:
            env_trial.close()

    bp = ax5.boxplot([ac_convergence, a2c_convergence], patch_artist=True)
    # Tick labels are set on the axis so they do not depend on the name of the
    # boxplot label keyword, which differs across Matplotlib versions.
    ax5.set_xticklabels(['Actor-Critic', 'A2C'])
    bp['boxes'][0].set_facecolor('lightblue')
    bp['boxes'][1].set_facecolor('lightcoral')
    ax5.set_ylabel('收敛所需回合数')
    ax5.set_title('收敛速度对比 (越低越好)')
    ax5.grid(True, alpha=0.3)

    # --- 6. Summary table -------------------------------------------------
    ax6 = axes[1, 2]
    ax6.axis('off')

    col_labels = ['算法', '更新方式', '优势估计', '熵正则', '收敛速度', '稳定性']
    data = [
        ['Actor-Critic', '单步TD', 'TD误差', '无', '较慢', '中等'],
        ['A2C', '多步回报', 'GAE', '有', '快', '稳定']
    ]

    table = ax6.table(cellText=data,
                      colLabels=col_labels,
                      cellLoc='center',
                      loc='center',
                      bbox=[0, 0, 1, 1])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 1.5)

    plt.suptitle('Actor-Critic vs A2C: 性能对比分析', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig('ac_vs_a2c_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    # --- Textual summary --------------------------------------------------
    print("\n" + "=" * 70)
    print("实验总结:")
    print("=" * 70)
    print(f"\n基础Actor-Critic:")
    print(f" - 平均收敛回合: {np.mean(ac_convergence):.1f} ± {np.std(ac_convergence):.1f}")
    print(f" - 最终平均奖励: {np.mean(ac_rewards[-50:]):.1f}")

    print(f"\nA2C (带GAE和熵正则):")
    print(f" - 平均收敛回合: {np.mean(a2c_convergence):.1f} ± {np.std(a2c_convergence):.1f}")
    print(f" - 最终平均奖励: {np.mean(a2c_rewards[-50:]):.1f}")

    improvement = (np.mean(ac_convergence) - np.mean(a2c_convergence)) / np.mean(ac_convergence) * 100
    print(f"\nA2C收敛速度提升: {improvement:.1f}%")

    return ac_agent, a2c_agent
620
+
621
+
622
+ # ============== 4. 主函数 ==============
623
def main():
    """Command-line entry point: train the selected algorithm on CartPole-v1 and plot the result.

    ``--algo`` selects ``ac`` (basic actor-critic), ``a2c`` or ``compare``;
    ``--episodes`` sets the number of training episodes for ``ac`` / ``a2c``;
    ``--render`` opens a window showing the environment while training.
    """
    print("=" * 70)
    print("Actor-Critic 方法完整实现")
    print("=" * 70)

    # Command-line options.
    import argparse
    parser = argparse.ArgumentParser(description='Actor-Critic算法实现')
    parser.add_argument('--algo', type=str, default='a2c',
                        choices=['ac', 'a2c', 'compare'],
                        help='选择算法: ac (基础Actor-Critic), a2c, compare')
    parser.add_argument('--episodes', type=int, default=5000,
                        help='训练回合数')
    parser.add_argument('--render', action='store_true',
                        help='渲染环境')
    args = parser.parse_args()

    if args.algo == 'compare':
        # compare_algorithms() creates the environments it needs itself.
        compare_algorithms()
        print("\n训练完成!")
        return

    # Gymnasium only draws a window when the environment is created with a
    # render mode; without it the --render flag has no visible effect.
    env = gym.make('CartPole-v1', render_mode='human' if args.render else None)
    try:
        if args.algo == 'ac':
            print("\n训练基础Actor-Critic...")
            agent = ActorCriticAgent(env, render=args.render)
        else:
            print("\n训练A2C...")
            agent = A2CAgent(env, learning_rate=3e-4)
        rewards = agent.train(num_episodes=args.episodes)
    finally:
        # Release the environment even if training is interrupted.
        env.close()

    if args.algo == 'ac':
        plt.figure(figsize=(10, 6))
        plt.plot(rewards, alpha=0.6, label='Episode Reward')

        # Moving average.
        window = 20
        if len(rewards) >= window:
            smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
            plt.plot(range(window - 1, len(smoothed) + window - 1),
                     smoothed, 'r-', linewidth=2, label=f'Moving Avg (window={window})')

        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.title('Actor-Critic Training on CartPole')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig('actor_critic_training.png', dpi=150)
        plt.show()
    else:
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Reward curve with a 20-episode moving average.
        axes[0, 0].plot(rewards, alpha=0.6, color='blue')
        if len(rewards) >= 20:
            smoothed = np.convolve(rewards, np.ones(20) / 20, mode='valid')
            axes[0, 0].plot(range(19, len(smoothed) + 19), smoothed, 'r-', linewidth=2)
        axes[0, 0].set_xlabel('Episode')
        axes[0, 0].set_ylabel('Total Reward')
        axes[0, 0].set_title('A2C Training Rewards')
        axes[0, 0].grid(True, alpha=0.3)

        # Policy loss.
        axes[0, 1].plot(agent.training_stats['policy_loss'], color='purple')
        axes[0, 1].set_xlabel('Update Step')
        axes[0, 1].set_ylabel('Policy Loss')
        axes[0, 1].set_title('Policy Loss')
        axes[0, 1].grid(True, alpha=0.3)

        # Value loss.
        axes[1, 0].plot(agent.training_stats['value_loss'], color='orange')
        axes[1, 0].set_xlabel('Update Step')
        axes[1, 0].set_ylabel('Value Loss')
        axes[1, 0].set_title('Value Loss')
        axes[1, 0].grid(True, alpha=0.3)

        # Policy entropy.
        axes[1, 1].plot(agent.training_stats['entropy'], color='green')
        axes[1, 1].set_xlabel('Update Step')
        axes[1, 1].set_ylabel('Policy Entropy')
        axes[1, 1].set_title('Policy Entropy')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('a2c_training.png', dpi=150)
        plt.show()

    print("\n训练完成!")
713
+
714
+
715
+ if __name__ == "__main__":
716
+ main()
examples/tutorials/rl/cart_pole/step_2_ppo_clip.py ADDED
@@ -0,0 +1,739 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 🌱 **第一代:REINFORCE (1992)**
5
+ ├── 核心创新:蒙特卡洛策略梯度
6
+ ├── 公式:∇θ ∝ Σ ∇θ log π * Gₜ
7
+ ├── 痛点:▸ 必须等回合结束 ▸ 方差极大 ▸ 样本效率低
8
+ └── 贡献:开创了策略梯度范式
9
+
10
+ 🌿 **第二代:REINFORCE with Baseline (2000s)**
11
+ ├── 核心创新:引入基线降低方差
12
+ ├── 公式:∇θ ∝ Σ ∇θ log π * (Gₜ - b(s))
13
+ ├── 痛点:▸ 仍需完整回合 ▸ 基线需要单独学习
14
+ └── 贡献:方差降低,训练更稳定
15
+
16
+ 🍃 **第三代:Actor-Critic (2000s)**
17
+ ├── 核心创新:单步更新,不再等待回合
18
+ ├── 公式:∇θ ∝ ∇θ log π * (r + γV(s') - V(s))
19
+ ├── 痛点:▸ 单步TD偏差大 ▸ 价值估计不准
20
+ └── 贡献:实现了真正的在线学习
21
+
22
+ 🌳 **第四代:A2C/A3C (2016)**
23
+ ├── 核心创新:优势函数 + 多步回报
24
+ ├── 公式:∇θ ∝ ∇θ log π * Â(s,a)
25
+ ├── GAE:Â = Σ (γλ)ᵏ δ_{t+k} ★滑动平均★
26
+ ├── 痛点:▸ 更新步长敏感 ▸ 容易破坏策略
27
+ └── 贡献:GAE成为标准配置
28
+
29
+ 🌲 **第五代:PPO (2017)**
30
+ ├── 核心创新:**所有前人智慧的集大成**
31
+ ├── 1️⃣ 继承AC/A2C:单步更新 + GAE
32
+ ├── 2️⃣ 继承重要性采样:可以复用数据
33
+ ├── 3️⃣ ✨ **独创:Clipped Surrogate Objective**
34
+ │ L = min(r(θ)Â, clip(r(θ), 1-ε, 1+ε)Â)
35
+ ├── 4️⃣ ✨ **独创:自适应KL惩罚**
36
+ └── 贡献:**稳定、高效、易用,成为事实标准**
37
+
38
+ 由于策略函数的优化步长直接受到优势值的影响,当步长过大时很容易直接跨过最优点,导致优化失败。
39
+ 虽然A2C 已经对历史优势进行了移动平均,但问题仍然存在。
40
+ 尤其是当训练的早期价值函数还没有获得较好的训练,这种问题尤其容易出现。
41
+
42
+ clip 当新模型相比于旧模型动作的概率变化幅度 ratio 已经较大时,则不再进行优化。
43
+ clip 操作会对 ratio 值进行截断,会切断梯度反向传播的通道。
44
+
45
+ 同时对动作的概率做熵最大化,目的是避免策略函数过早确定化,保持探索能力。
46
+
47
+ PPO-clip 使用 ratio * advantages
48
+ 其中:ratio = probs / old_probs
49
+ 而不是 A2C 中的 log_probs * advantages
50
+
51
+
52
+ """
53
+ import gymnasium as gym
54
+ import torch
55
+ import torch.nn as nn
56
+ import torch.optim as optim
57
+ from torch.distributions import Categorical
58
+ import numpy as np
59
+ import matplotlib.pyplot as plt
60
+ from collections import deque
61
+ import time
62
+
63
+
64
+ # ============== 1. PPO-Clip Network ==============
65
class PPONetwork(nn.Module):
    """Actor-critic network for PPO built on a shared trunk.

    A two-layer tanh MLP produces features that feed two heads: the actor
    (softmax over the discrete actions) and the critic (scalar state value).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        """Build the shared trunk and both heads.

        Args:
            state_dim: Size of the observation vector.
            action_dim: Number of discrete actions.
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()

        # Layers are created trunk -> actor -> critic so parameter
        # initialisation consumes the RNG in a fixed order.
        trunk = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        ]
        self.feature_layer = nn.Sequential(*trunk)

        policy_head = [nn.Linear(hidden_dim, action_dim), nn.Softmax(dim=-1)]
        self.actor = nn.Sequential(*policy_head)

        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` for a batch of states."""
        hidden = self.feature_layer(state)
        return self.actor(hidden), self.critic(hidden)
93
+
94
+
95
+ # ============== 2. PPO-Clip Agent ==============
96
class PPOClipAgent:
    """PPO-Clip agent for discrete-action environments (Gymnasium API).

    The agent stores transitions in an in-memory buffer, estimates advantages
    with GAE, and optimises the clipped surrogate objective together with a
    value loss and an entropy bonus.
    """

    # Names of the per-update statistics returned by ``update`` and logged
    # by ``train``.
    _UPDATE_STAT_KEYS = (
        'policy_loss',
        'value_loss',
        'entropy',
        'clip_fraction',
        'explained_variance',
    )

    def __init__(self,
                 env,
                 learning_rate=3e-4,
                 gamma=0.99,
                 gae_lambda=0.95,
                 clip_epsilon=0.2,
                 entropy_coef=0.01,
                 value_coef=0.5,
                 max_grad_norm=0.5,
                 update_epochs=4,
                 mini_batch_size=64,
                 horizon=2048,
                 hidden_dim=64):
        """Create the network, optimizer, buffer and statistics containers.

        Args:
            env: Environment exposing ``observation_space.shape`` and
                ``action_space.n`` plus the Gymnasium ``reset``/``step`` API.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE decay parameter.
            clip_epsilon: Half-width of the ratio clipping interval.
            entropy_coef: Weight of the entropy bonus.
            value_coef: Weight of the value loss.
            max_grad_norm: Gradient-norm clipping threshold.
            update_epochs: Passes over the buffer per update.
            mini_batch_size: Mini-batch size; also the minimum buffer size
                required before ``update`` does any work.
            horizon: Buffer size that forces an update during an episode.
            hidden_dim: Hidden width passed to ``PPONetwork``.
        """
        self.env = env
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.mini_batch_size = mini_batch_size
        self.horizon = horizon

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.network = PPONetwork(self.state_dim, self.action_dim, hidden_dim).to(self.device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)

        self.reset_buffer()

        self.training_stats = {
            'episode_rewards': [],
            'episode_lengths': [],
            'policy_loss': [],
            'value_loss': [],
            'entropy': [],
            'clip_fraction': [],
            'explained_variance': []
        }

        # Rolling windows used by the table logger.
        self.recent_rewards = deque(maxlen=20)
        self.recent_lengths = deque(maxlen=20)

    def reset_buffer(self):
        """Replace the experience buffer with a fresh, empty one."""
        self.buffer = {
            'states': [],
            'actions': [],
            'rewards': [],
            'next_states': [],
            'dones': [],
            'log_probs': [],
            'values': []
        }

    def select_action(self, state, eval_mode=False):
        """Choose an action for a single state with the current policy.

        Args:
            state: One observation (array-like of floats).
            eval_mode: When true, take the most probable action instead of
                sampling.

        Returns:
            Tuple ``(action, log_prob, state_value)`` of Python scalars.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            action_probs, state_value = self.network(state)

            m = Categorical(action_probs)
            if eval_mode:
                # Greedy choice; no sample is drawn, so evaluation does not
                # advance the global RNG stream.
                action = torch.argmax(action_probs, dim=-1)
            else:
                action = m.sample()
            log_prob = m.log_prob(action)

        return action.item(), log_prob.cpu().item(), state_value.cpu().item()

    def store_transition(self, state, action, reward, next_state, done, log_prob, value):
        """Append one transition to the experience buffer."""
        self.buffer['states'].append(state)
        self.buffer['actions'].append(action)
        self.buffer['rewards'].append(reward)
        self.buffer['next_states'].append(next_state)
        self.buffer['dones'].append(done)
        self.buffer['log_probs'].append(log_prob)
        self.buffer['values'].append(value)

    def compute_gae(self, rewards, values, next_values, dones):
        """Compute Generalized Advantage Estimation for a trajectory.

        Args:
            rewards: Per-step rewards.
            values: Value estimates of the visited states.
            next_values: Value estimates of the successor states.
            dones: Per-step episode-end flags (truthy/1 at episode ends).

        Returns:
            List of advantages in time order, same length as ``rewards``.
        """
        advantages = []
        gae = 0

        # Walk backwards; append then reverse once instead of insert(0, ...),
        # which would be quadratic in the trajectory length.
        for t in reversed(range(len(rewards))):
            not_done = 1 - dones[t]
            delta = rewards[t] + self.gamma * next_values[t] * not_done - values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages.append(gae)

        advantages.reverse()
        return advantages

    def update(self):
        """Run the PPO-Clip optimisation over the buffered transitions.

        Returns:
            ``None`` when the buffer holds fewer than ``mini_batch_size``
            transitions (the buffer is kept). Otherwise a dict with the mean
            ``policy_loss``, ``value_loss``, ``entropy`` and ``clip_fraction``
            over all mini-batch steps plus the ``explained_variance`` of the
            updated critic; the buffer is reset afterwards.
        """
        if len(self.buffer['rewards']) < self.mini_batch_size:
            return None

        # Convert to tensors
        states = torch.as_tensor(np.array(self.buffer['states']), dtype=torch.float32).to(self.device)
        actions = torch.as_tensor(self.buffer['actions'], dtype=torch.long).to(self.device)
        old_log_probs = torch.as_tensor(self.buffer['log_probs'], dtype=torch.float32).to(self.device)

        values = torch.as_tensor(self.buffer['values'], dtype=torch.float32).to(self.device)
        next_states = torch.as_tensor(np.array(self.buffer['next_states']), dtype=torch.float32).to(self.device)
        dones = torch.as_tensor(self.buffer['dones'], dtype=torch.float32).to(self.device)
        rewards = self.buffer['rewards']

        # Compute next state values. squeeze(-1) drops only the value
        # dimension, so a single-transition batch stays 1-D.
        with torch.no_grad():
            _, next_values = self.network(next_states)
            next_values = next_values.squeeze(-1).cpu().numpy()

        # Compute advantages and returns
        advantages = self.compute_gae(
            rewards,
            values.cpu().numpy(),
            next_values,
            dones.cpu().numpy()
        )
        advantages = torch.as_tensor(np.asarray(advantages, dtype=np.float32)).to(self.device)
        returns = advantages + values

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        total_clip_fraction = 0
        n_updates = 0  # counted as we go so the averages below are exact

        dataset_size = len(states)

        # Multiple epochs of PPO update
        for _ in range(self.update_epochs):
            indices = np.random.permutation(dataset_size)

            for start in range(0, dataset_size, self.mini_batch_size):
                batch_indices = indices[start:start + self.mini_batch_size]

                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]

                # Forward pass
                action_probs, state_values = self.network(batch_states)
                state_values = state_values.squeeze(-1)

                # Compute log probs and entropy
                m = Categorical(action_probs)
                log_probs = m.log_prob(batch_actions)
                entropy = m.entropy().mean()

                # Importance sampling ratio
                ratio = torch.exp(log_probs - batch_old_log_probs)

                # PPO-Clip objective
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss
                value_loss = nn.MSELoss()(state_values, batch_returns)

                # Entropy loss (negative sign: entropy is maximised)
                entropy_loss = -self.entropy_coef * entropy

                # Total loss
                total_loss = policy_loss + self.value_coef * value_loss + entropy_loss

                # Backward pass
                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()

                # Fraction of samples whose ratio left the clipping interval
                with torch.no_grad():
                    clip_mask = (ratio < 1 - self.clip_epsilon) | (ratio > 1 + self.clip_epsilon)
                    total_clip_fraction += clip_mask.float().mean().item()

                n_updates += 1

        # Compute explained variance
        with torch.no_grad():
            _, pred_values = self.network(states)
            pred_values = pred_values.squeeze(-1)
            explained_variance = 1 - torch.var(returns - pred_values) / (torch.var(returns) + 1e-8)
            explained_variance = explained_variance.cpu().item()

        stats = {
            'policy_loss': total_policy_loss / n_updates,
            'value_loss': total_value_loss / n_updates,
            'entropy': total_entropy / n_updates,
            'clip_fraction': total_clip_fraction / n_updates,
            'explained_variance': explained_variance
        }

        self.reset_buffer()

        return stats

    def print_header(self):
        """Print the header of the training-progress table."""
        print("\n" + "=" * 130)
        print(
            f"{'Episode':>8} | {'Avg Reward':>12} | {'Avg Length':>12} | {'Policy Loss':>12} | {'Value Loss':>12} | {'Entropy':>10} | {'Clip%':>10} | {'Expl Var':>12} | {'Time':>10}")
        print("-" * 130)

    def print_row(self, episode, total_episodes, stats=None, elapsed_time=None):
        """Print one row of the training-progress table.

        Args:
            episode: Number of completed episodes.
            total_episodes: Planned number of episodes.
            stats: Optional dict of update statistics; dashes are printed for
                the loss columns when it is missing or empty.
            elapsed_time: Seconds since training started; treated as 0 when
                not given.
        """
        # Rebuild the rolling windows from the newest recorded episodes so the
        # averages cover the most recent episodes rather than one sample per
        # logging call.
        self.recent_rewards.clear()
        self.recent_rewards.extend(
            self.training_stats['episode_rewards'][-self.recent_rewards.maxlen:])
        self.recent_lengths.clear()
        self.recent_lengths.extend(
            self.training_stats['episode_lengths'][-self.recent_lengths.maxlen:])

        # Calculate averages
        avg_reward = np.mean(self.recent_rewards) if self.recent_rewards else 0
        avg_length = np.mean(self.recent_lengths) if self.recent_lengths else 0

        elapsed = 0.0 if elapsed_time is None else elapsed_time

        # Format output
        if stats:
            print(f"{episode:>8}/{total_episodes} | "
                  f"{avg_reward:>12.2f} | "
                  f"{avg_length:>12.1f} | "
                  f"{stats['policy_loss']:>12.4f} | "
                  f"{stats['value_loss']:>12.4f} | "
                  f"{stats['entropy']:>10.4f} | "
                  f"{stats['clip_fraction']:>10.3f} | "
                  f"{stats['explained_variance']:>12.3f} | "
                  f"{elapsed:>10.1f}s")
        else:
            print(f"{episode:>8}/{total_episodes} | "
                  f"{avg_reward:>12.2f} | "
                  f"{avg_length:>12.1f} | "
                  f"{'-':>12} | "
                  f"{'-':>12} | "
                  f"{'-':>10} | "
                  f"{'-':>10} | "
                  f"{'-':>12} | "
                  f"{elapsed:>10.1f}s")

    def print_summary(self, total_time, num_episodes):
        """Print the end-of-training summary."""
        print("=" * 130)
        print(f"\n🎯 Training completed! Episodes: {num_episodes}, Total time: {total_time:.1f}s")

        if len(self.training_stats['episode_rewards']) >= 20:
            final_avg_reward = np.mean(self.training_stats['episode_rewards'][-20:])
            final_avg_length = np.mean(self.training_stats['episode_lengths'][-20:])
            print(f"📊 Last 20 episodes - Avg Reward: {final_avg_reward:.2f}, Avg Length: {final_avg_length:.1f}")

        if self.training_stats['policy_loss']:
            print(f"📉 Final Policy Loss: {self.training_stats['policy_loss'][-1]:.4f}")
            print(f"📉 Final Value Loss: {self.training_stats['value_loss'][-1]:.4f}")
            print(f"🎲 Final Entropy: {self.training_stats['entropy'][-1]:.4f}")
            print(f"✂️ Final Clip Fraction: {self.training_stats['clip_fraction'][-1]:.3f}")
            print(f"📈 Final Explained Variance: {self.training_stats['explained_variance'][-1]:.3f}")

        print("=" * 130)

    def train(self, num_episodes=1000, max_steps_per_episode=500, log_interval=20):
        """Train the agent for a fixed number of episodes.

        An update is attempted whenever the buffer reaches ``horizon``
        transitions or an episode ends.

        Args:
            num_episodes: Number of episodes to run.
            max_steps_per_episode: Step cap per episode.
            log_interval: Print a table row every this many episodes.

        Returns:
            Tuple ``(episode_rewards, episode_lengths)`` of the recorded lists.
        """
        print("\n" + "🚀" * 65)
        print("PPO-Clip Training Started (Gymnasium API)")
        print("🚀" * 65)

        print(f"\n📋 Hyperparameters:")
        print(f" Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")
        print(f" Gamma: {self.gamma:.2f}, GAE Lambda: {self.gae_lambda:.2f}")
        print(f" Clip Epsilon: {self.clip_epsilon:.2f}, Update Epochs: {self.update_epochs}")
        print(f" Mini-batch: {self.mini_batch_size}, Horizon: {self.horizon}")
        print(f" Entropy Coef: {self.entropy_coef:.3f}, Value Coef: {self.value_coef:.1f}")
        print(f" Device: {self.device}")

        self.print_header()

        episode = 0
        start_time = time.time()

        while episode < num_episodes:
            # Gymnasium API: reset() returns state, info
            state, _ = self.env.reset()
            episode_reward = 0
            episode_step = 0

            while episode_step < max_steps_per_episode:
                action, log_prob, value = self.select_action(state)

                # Gymnasium API: step() returns next_state, reward, terminated, truncated, info
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                # NOTE(review): truncation is stored as a terminal transition,
                # so no value is bootstrapped past a time-limit cut-off.
                done = terminated or truncated

                self.store_transition(state, action, reward, next_state, done, log_prob, value)

                episode_reward += reward
                episode_step += 1
                state = next_state

                # Update when buffer is full or episode ends
                if len(self.buffer['rewards']) >= self.horizon or done:
                    stats = self.update()
                    if stats:
                        for key in self._UPDATE_STAT_KEYS:
                            self.training_stats[key].append(stats[key])

                if done:
                    break

            # Record episode statistics
            self.training_stats['episode_rewards'].append(episode_reward)
            self.training_stats['episode_lengths'].append(episode_step)

            episode += 1

            # Log periodically
            if episode % log_interval == 0:
                elapsed = time.time() - start_time

                recent_stats = None
                if self.training_stats['policy_loss']:
                    recent_stats = {
                        key: self.training_stats[key][-1]
                        for key in self._UPDATE_STAT_KEYS
                    }

                self.print_row(episode, num_episodes, recent_stats, elapsed)

        total_time = time.time() - start_time
        self.print_summary(total_time, num_episodes)

        return self.training_stats['episode_rewards'], self.training_stats['episode_lengths']

    def save(self, path):
        """Save network weights, optimizer state and training statistics."""
        torch.save({
            'network_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'training_stats': self.training_stats
        }, path)
        print(f"\n💾 Model saved to {path}")

    def load(self, path):
        """Restore a checkpoint written by ``save``.

        Tensors are mapped onto this agent's device so a checkpoint saved on
        a GPU can be loaded on a CPU-only host.
        """
        # NOTE: torch.load unpickles data; only load checkpoints you trust.
        checkpoint = torch.load(path, map_location=self.device)
        self.network.load_state_dict(checkpoint['network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.training_stats = checkpoint['training_stats']
        print(f"\n📂 Model loaded from {path}")
475
+
476
+
477
+ # ============== 3. Evaluation Function ==============
478
def evaluate_agent(agent, env, num_episodes=10, render=False):
    """Run greedy evaluation episodes and print a per-episode table.

    Args:
        agent: Object whose ``select_action(state, eval_mode=True)`` returns
            a 3-tuple with the action first.
        env: Environment following the Gymnasium ``reset``/``step`` API.
        num_episodes: Number of evaluation episodes.
        render: When true, call ``env.render()`` and pause briefly each step.

    Returns:
        Tuple ``(episode_rewards, episode_lengths)`` of per-episode lists.
    """
    banner = "🎯" * 35
    rule = "-" * 60

    print("\n" + banner)
    print("Evaluation Started")
    print(banner)

    print(f"\n{'Episode':^12} | {'Reward':^12} | {'Length':^12} | {'Avg Reward':^14}")
    print(rule)

    episode_rewards, episode_lengths = [], []

    for episode in range(num_episodes):
        observation, _ = env.reset()
        total_reward = 0
        steps = 0
        finished = False

        while not finished:
            if render:
                env.render()
                time.sleep(0.02)

            action, _, _ = agent.select_action(observation, eval_mode=True)
            observation, reward, terminated, truncated, _ = env.step(action)

            total_reward += reward
            steps += 1
            # Either a true terminal state or a time-limit cut-off ends the episode.
            finished = terminated or truncated

        episode_rewards.append(total_reward)
        episode_lengths.append(steps)

        running_mean = np.mean(episode_rewards)
        print(f"{episode + 1:^12} | {total_reward:^12.1f} | {steps:^12} | {running_mean:^14.2f}")

    print(rule)
    print(f"\n📊 Evaluation Results ({num_episodes} episodes):")
    print(f" Avg Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
    print(f" Avg Length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")
    print(f" Max Reward: {np.max(episode_rewards):.2f}")
    print(f" Min Reward: {np.min(episode_rewards):.2f}")
    print(f" Success Rate (>=475): {np.mean(np.array(episode_rewards) >= 475) * 100:.1f}%")
    print("=" * 60)

    return episode_rewards, episode_lengths
527
+
528
+
529
+ # ============== 4. Visualization Function (English Only) ==============
530
def plot_training_results(agent, save_path='ppo_training_results.png'):
    """Plot the agent's recorded training statistics on a 2x3 grid.

    Panels: episode rewards, episode lengths, policy loss, value loss,
    policy entropy, and clip fraction with explained variance on a twin axis.
    The figure is saved to ``save_path`` and then shown. All labels are
    English only, so no CJK font is required.

    Args:
        agent: Object exposing a ``training_stats`` dict of lists.
        save_path: Output image path.
    """
    stats = agent.training_stats
    window = 20

    def moving_average(series):
        # 'valid' mode: the first averaged point corresponds to index window - 1.
        averaged = np.convolve(series, np.ones(window) / window, mode='valid')
        return range(window - 1, len(averaged) + window - 1), averaged

    def label_axis(ax, xlabel, ylabel, title):
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('PPO-Clip Training Results (CartPole-v1)', fontsize=16, y=1.02)
    (ax_reward, ax_length, ax_policy), (ax_value, ax_entropy, ax_clip) = axes

    # 1. Episode Rewards
    rewards = stats['episode_rewards']
    ax_reward.plot(rewards, alpha=0.3, color='blue', label='Raw Reward')
    if len(rewards) >= window:
        xs, ys = moving_average(rewards)
        ax_reward.plot(xs, ys, 'r-', linewidth=2, label=f'{window}-Episode MA')
    ax_reward.axhline(y=500, color='green', linestyle='--', alpha=0.7, label='Target (500)')
    label_axis(ax_reward, 'Episode', 'Total Reward', 'Training Rewards')
    ax_reward.legend()
    ax_reward.grid(True, alpha=0.3)

    # 2. Episode Lengths
    lengths = stats['episode_lengths']
    ax_length.plot(lengths, alpha=0.3, color='orange', label='Episode Length')
    if len(lengths) >= window:
        xs, ys = moving_average(lengths)
        ax_length.plot(xs, ys, 'r-', linewidth=2)
    label_axis(ax_length, 'Episode', 'Episode Length', 'Episode Lengths')
    ax_length.grid(True, alpha=0.3)

    # 3-5. Policy loss, value loss and entropy share one layout.
    simple_panels = (
        (ax_policy, 'policy_loss', 'purple', 'Policy Loss', 'Policy Loss'),
        (ax_value, 'value_loss', 'brown', 'Value Loss', 'Value Loss'),
        (ax_entropy, 'entropy', 'green', 'Policy Entropy', 'Policy Entropy (Exploration)'),
    )
    for ax, key, color, ylabel, title in simple_panels:
        if stats[key]:
            ax.plot(stats[key], color=color, linewidth=1.5)
        label_axis(ax, 'Update Step', ylabel, title)
        ax.grid(True, alpha=0.3)

    # 6. Clip Fraction & Explained Variance
    if stats['clip_fraction']:
        ax_clip.plot(stats['clip_fraction'], color='red', linewidth=1.5, label='Clip Fraction')
    ax_clip.set_xlabel('Update Step')
    ax_clip.set_ylabel('Clip Fraction', color='red')
    ax_clip.tick_params(axis='y', labelcolor='red')
    ax_clip.grid(True, alpha=0.3)

    ax_variance = ax_clip.twinx()
    if stats['explained_variance']:
        ax_variance.plot(stats['explained_variance'], color='blue', linewidth=1.5,
                         label='Explained Variance')
    ax_variance.set_ylabel('Explained Variance', color='blue')
    ax_variance.tick_params(axis='y', labelcolor='blue')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\n📸 Training results saved to {save_path}")
619
+
620
+
621
+ # ============== 5. Main Function ==============
622
def main():
    """Train, save, plot, evaluate and demo a PPO-Clip agent on CartPole-v1."""

    # Create environment (Gymnasium)
    env = gym.make('CartPole-v1')

    # PPO-Clip hyperparameters
    config = dict(
        learning_rate=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        entropy_coef=0.01,
        value_coef=0.5,
        max_grad_norm=0.5,
        update_epochs=4,
        mini_batch_size=64,
        horizon=2048,
        hidden_dim=64,
    )

    separator = "=" * 90
    print("\n" + separator)
    print("PPO-Clip for CartPole-v1 (Gymnasium)")
    print(separator)
    print("\n📋 Hyperparameters:")
    for key, value in config.items():
        print(f" {key:20}: {value}")

    # Create PPO agent
    agent = PPOClipAgent(env, **config)

    # Extra environments are created lazily; None means "never opened".
    eval_env = None
    demo_env = None

    try:
        # Train
        agent.train(num_episodes=5000, log_interval=20)

        # Save model
        agent.save('ppo_cartpole_gymnasium.pth')

        # Plot results
        plot_training_results(agent)

        # Evaluate
        print("\n")
        eval_env = gym.make('CartPole-v1')
        evaluate_agent(agent, eval_env, num_episodes=20, render=False)

        # Demo (with rendering)
        print("\n")
        demo_env = gym.make('CartPole-v1', render_mode='human')
        evaluate_agent(agent, demo_env, num_episodes=3, render=True)

    except KeyboardInterrupt:
        print("\n\n⚠️ Training interrupted, saving model...")
        agent.save('ppo_cartpole_gymnasium_interrupted.pth')

    finally:
        env.close()
        for extra_env in (eval_env, demo_env):
            if extra_env is not None:
                extra_env.close()
683
+
684
+
685
+ # ============== 6. Hyperparameter Sweep ==============
686
def hyperparameter_sweep():
    """Train one agent per clip-epsilon value and plot smoothed rewards.

    Returns:
        Dict mapping a label such as ``'ε=0.2'`` to that run's list of
        episode rewards.
    """
    banner = "🔬" * 45
    print("\n" + banner)
    print("PPO-Clip Hyperparameter Sweep")
    print(banner)

    results = {}

    for clip_eps in (0.1, 0.2, 0.3):
        print(f"\n📊 Testing clip_epsilon = {clip_eps}")
        print("-" * 60)

        env = gym.make('CartPole-v1')
        agent = PPOClipAgent(
            env,
            learning_rate=3e-4,
            clip_epsilon=clip_eps,
            update_epochs=4,
            horizon=2048,
            mini_batch_size=64
        )

        episode_rewards, _ = agent.train(num_episodes=2000, log_interval=20)
        results[f'ε={clip_eps}'] = episode_rewards
        env.close()

    # Plot comparison
    window = 20
    kernel = np.ones(window) / window
    plt.figure(figsize=(12, 6))

    for label, episode_rewards in results.items():
        averaged = np.convolve(episode_rewards, kernel, mode='valid')
        # First averaged point sits at episode index window - 1.
        xs = range(window - 1, len(averaged) + window - 1)
        plt.plot(xs, averaged, linewidth=2, label=label)

    plt.xlabel('Episode')
    plt.ylabel(f'Avg Reward ({window}-Episode MA)')
    plt.title('PPO-Clip: Different Clip Epsilon Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('ppo_clip_comparison.png', dpi=150)
    plt.show()

    return results
732
+
733
+
734
+ if __name__ == "__main__":
735
+ main()
736
+ # hyperparameter_sweep() # Uncomment to run hyperparameter sweep
737
+
738
+
739
+
examples/tutorials/rl/cart_pole/step_2_ppo_penalty.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ KL散度损失就是为了直接约束新旧策略之间的变化程度。
5
+
6
+ 使用 KL散度的好处:
7
+ penalty 更像是,不管优势多大,它总能将其进行可控的相对缩放。
8
+
9
+
10
+ clip
11
+ 当新模型相比旧模型的变化幅度已较大时,clip以阻断优化,会切断梯度传导。
12
+ penalty
13
+ 直接将新模型与旧模型的动作概率约束在一个 target_kl 附近,限制每一次迭代的优化幅度。
14
+ KL散度不会切断梯度传导,总是可以进行有效的优化。
15
+
16
+ """
17
+ import gymnasium as gym
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.optim as optim
21
+ from torch.distributions import Categorical
22
+ import numpy as np
23
+ import matplotlib.pyplot as plt
24
+ from collections import deque
25
+ import time
26
+
27
+
28
+ # ============== 1. PPO-Penalty Network ==============
29
class PPONetwork(nn.Module):
    """Shared-trunk actor-critic network used by the PPO-Penalty agent.

    The trunk is a two-layer tanh MLP. The actor head maps features to a
    softmax distribution over actions; the critic head maps them to a scalar.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        """Create the trunk, the actor head and the critic head.

        Args:
            state_dim: Observation vector length.
            action_dim: Number of discrete actions.
            hidden_dim: Width of each hidden layer.
        """
        super().__init__()

        # Construction order (trunk, actor, critic) is kept fixed because
        # each Linear layer draws its initial weights from the RNG.
        first_hidden = nn.Linear(state_dim, hidden_dim)
        second_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.feature_layer = nn.Sequential(first_hidden, nn.Tanh(), second_hidden, nn.Tanh())

        action_logits = nn.Linear(hidden_dim, action_dim)
        self.actor = nn.Sequential(action_logits, nn.Softmax(dim=-1))

        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` for the given states."""
        shared = self.feature_layer(state)
        probabilities = self.actor(shared)
        estimate = self.critic(shared)
        return probabilities, estimate
57
+
58
+
59
+ # ============== 2. PPO-Penalty Agent ==============
60
+ class PPOPenaltyAgent:
61
+ """PPO-Penalty (Adaptive KL Penalty) with Gymnasium API"""
62
+
63
def __init__(self,
             env,
             learning_rate=3e-4,
             gamma=0.99,
             gae_lambda=0.95,
             kl_target=0.01,  # Target KL divergence
             kl_coef_init=1.0,  # Initial KL penalty coefficient
             kl_coef_adapt=1.5,  # KL coefficient adaptation rate
             entropy_coef=0.01,
             value_coef=0.5,
             max_grad_norm=0.5,
             update_epochs=10,  # PPO-Penalty typically uses more epochs
             mini_batch_size=64,
             horizon=2048,
             hidden_dim=64):
    """Set up networks, optimizer, buffer and statistics containers.

    Args:
        env: Environment exposing ``observation_space.shape`` and
            ``action_space.n``.
        learning_rate: Adam learning rate for the policy network.
        gamma: Discount factor.
        gae_lambda: GAE decay parameter.
        kl_target: Target KL divergence between old and new policy.
        kl_coef_init: Initial value of the KL penalty coefficient.
        kl_coef_adapt: Adaptation rate of the KL coefficient.
        entropy_coef: Weight of the entropy bonus.
        value_coef: Weight of the value loss.
        max_grad_norm: Gradient-norm clipping threshold.
        update_epochs: Passes over the buffer per update.
        mini_batch_size: Mini-batch size.
        horizon: Buffer size stored for the training loop.
        hidden_dim: Hidden width passed to ``PPONetwork``.
    """
    self.env = env

    # Return / advantage estimation
    self.gamma, self.gae_lambda = gamma, gae_lambda

    # KL penalty settings; the coefficient starts at its initial value.
    self.kl_target = kl_target
    self.kl_coef = kl_coef_init
    self.kl_coef_adapt = kl_coef_adapt

    # Loss weights and optimisation schedule
    self.entropy_coef, self.value_coef = entropy_coef, value_coef
    self.max_grad_norm = max_grad_norm
    self.update_epochs = update_epochs
    self.mini_batch_size = mini_batch_size
    self.horizon = horizon

    self.state_dim = env.observation_space.shape[0]
    self.action_dim = env.action_space.n

    use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda" if use_cuda else "cpu")

    # Policy network
    self.policy = PPONetwork(self.state_dim, self.action_dim, hidden_dim).to(self.device)
    self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

    # Old policy for KL calculation; it starts as a copy of the policy.
    self.old_policy = PPONetwork(self.state_dim, self.action_dim, hidden_dim).to(self.device)
    self.update_old_policy()

    self.reset_buffer()

    stat_names = (
        'episode_rewards',
        'episode_lengths',
        'policy_loss',
        'value_loss',
        'entropy',
        'kl_divergence',
        'kl_coef',
        'explained_variance',
    )
    self.training_stats = {name: [] for name in stat_names}

    # Statistics for table logging
    self.recent_rewards = deque(maxlen=20)
    self.recent_lengths = deque(maxlen=20)
121
+
122
def update_old_policy(self):
    """Overwrite the old policy's parameters with the current policy's."""
    current_weights = self.policy.state_dict()
    self.old_policy.load_state_dict(current_weights)
125
+
126
def reset_buffer(self):
    """Replace the experience buffer with empty per-field lists."""
    fields = (
        'states',
        'actions',
        'rewards',
        'next_states',
        'dones',
        'log_probs',
        'values',
    )
    # Each field gets its own list object.
    self.buffer = {field: [] for field in fields}
137
+
138
def select_action(self, state, eval_mode=False):
    """Choose an action for a single state with the current policy.

    Args:
        state: One observation (array-like of floats).
        eval_mode: When true, take the most probable action instead of
            sampling.

    Returns:
        Tuple ``(action, log_prob, state_value)`` of Python scalars.
    """
    state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

    with torch.no_grad():
        action_probs, state_value = self.policy(state)

        m = Categorical(action_probs)
        if eval_mode:
            # Greedy choice; no sample is drawn, so evaluation does not
            # advance the global RNG stream.
            action = torch.argmax(action_probs, dim=-1)
        else:
            action = m.sample()
        log_prob = m.log_prob(action)

    return action.item(), log_prob.cpu().item(), state_value.cpu().item()
154
+
155
def store_transition(self, state, action, reward, next_state, done, log_prob, value):
    """Append one environment step to the rollout buffer, field by field."""
    step_record = (
        ('states', state),
        ('actions', action),
        ('rewards', reward),
        ('next_states', next_state),
        ('dones', done),
        ('log_probs', log_prob),
        ('values', value),
    )
    for field, item in step_record:
        self.buffer[field].append(item)
164
+
165
def compute_gae(self, rewards, values, next_values, dones):
    """Compute Generalized Advantage Estimation for a rollout.

    Walks the rollout backwards, accumulating
    ``delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)`` with decay
    ``gamma * gae_lambda``. The accumulator is reset wherever ``done_t`` is 1.

    Args:
        rewards: Per-step rewards (indexable sequence).
        values: Value estimates for the visited states.
        next_values: Value estimates for the successor states.
        dones: Per-step done flags (1 / True ends the trajectory).

    Returns:
        List of advantages, one per step, in chronological order.
    """
    advantages = []
    gae = 0

    for t in reversed(range(len(rewards))):
        not_done = 1 - dones[t]
        delta = rewards[t] + self.gamma * next_values[t] * not_done - values[t]
        gae = delta + self.gamma * self.gae_lambda * not_done * gae
        # Append and reverse once at the end: O(n) instead of the O(n^2)
        # cost of repeatedly inserting at the front of a list.
        advantages.append(gae)

    advantages.reverse()
    return advantages
176
+
177
def compute_kl_divergence(self, states, actions):
    """Return the mean KL(old policy || current policy) over ``states``.

    ``actions`` is accepted for interface compatibility but is not used.
    The result is a Python float and no gradients are tracked.
    """
    with torch.no_grad():
        # Old policy first, then the current one (same call order as before).
        reference_dist = Categorical(self.old_policy(states)[0])
        current_dist = Categorical(self.policy(states)[0])
        mean_kl = torch.distributions.kl.kl_divergence(reference_dist, current_dist).mean()

    return mean_kl.item()
192
+
193
def update(self):
    """Run one PPO-Penalty update (adaptive KL penalty) on the buffered rollout.

    Steps: compute GAE advantages and returns, run ``update_epochs`` passes of
    shuffled mini-batch gradient steps on
    ``-(ratio * A) + kl_coef * KL(old || new) + value_coef * MSE - entropy_coef * H``,
    adapt ``kl_coef`` from the measured average KL, then sync the old policy
    and clear the buffer.

    Returns:
        ``None`` when fewer than ``mini_batch_size`` transitions are buffered
        (the buffer is kept). Otherwise a dict with the averaged
        ``policy_loss``, ``value_loss``, ``entropy``, ``kl_divergence``, the
        new ``kl_coef`` and the ``explained_variance`` of the value head.
    """
    if len(self.buffer['rewards']) < self.mini_batch_size:
        return None

    # Convert the buffered rollout to tensors.
    states = torch.FloatTensor(np.array(self.buffer['states'])).to(self.device)
    actions = torch.LongTensor(self.buffer['actions']).to(self.device)
    old_log_probs = torch.FloatTensor(self.buffer['log_probs']).to(self.device)

    values = torch.FloatTensor(self.buffer['values']).to(self.device)
    next_states = torch.FloatTensor(np.array(self.buffer['next_states'])).to(self.device)
    dones = torch.FloatTensor(self.buffer['dones']).to(self.device)
    rewards = self.buffer['rewards']

    # Value estimates of the successor states (bootstrap targets).
    with torch.no_grad():
        _, next_values = self.policy(next_states)
        # reshape(-1) keeps a 1-D array even for a single transition, where a
        # bare squeeze() would collapse it to 0-d and break indexing.
        next_values = next_values.reshape(-1).cpu().numpy()

    # Advantages and returns.
    advantages = self.compute_gae(
        rewards,
        values.cpu().numpy(),
        next_values,
        dones.cpu().numpy()
    )
    advantages = torch.FloatTensor(advantages).to(self.device)
    returns = advantages + values

    # Normalise advantages; the std of a single element is NaN, so skip then.
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    total_policy_loss = 0
    total_value_loss = 0
    total_entropy = 0
    total_kl = 0
    n_updates = 0  # counted exactly; the last mini-batch may be partial or absent

    dataset_size = len(states)

    for _ in range(self.update_epochs):
        indices = np.random.permutation(dataset_size)

        for start in range(0, dataset_size, self.mini_batch_size):
            batch_indices = indices[start:start + self.mini_batch_size]

            batch_states = states[batch_indices]
            batch_actions = actions[batch_indices]
            batch_old_log_probs = old_log_probs[batch_indices]
            batch_advantages = advantages[batch_indices]
            batch_returns = returns[batch_indices]

            # Forward pass.
            action_probs, state_values = self.policy(batch_states)
            state_values = state_values.reshape(-1)

            m = Categorical(action_probs)
            log_probs = m.log_prob(batch_actions)
            entropy = m.entropy().mean()

            # Importance-sampling ratio against the rollout policy.
            ratio = torch.exp(log_probs - batch_old_log_probs)

            # PPO-Penalty surrogate (KL penalty instead of clipping).
            policy_loss = -(ratio * batch_advantages).mean()

            # Only the old policy's forward pass is excluded from autograd.
            # The KL itself must stay differentiable w.r.t. the current
            # policy, otherwise the penalty is a constant and has no effect.
            with torch.no_grad():
                old_probs, _ = self.old_policy(batch_states)
            old_m = Categorical(old_probs)
            kl_batch = torch.distributions.kl.kl_divergence(old_m, m).mean()
            total_kl += kl_batch.item()

            policy_loss_penalized = policy_loss + self.kl_coef * kl_batch

            value_loss = nn.MSELoss()(state_values, batch_returns)

            # Entropy bonus (encourages exploration).
            entropy_loss = -self.entropy_coef * entropy

            total_loss = policy_loss_penalized + self.value_coef * value_loss + entropy_loss

            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_entropy += entropy.item()
            n_updates += 1

    # Guard against update_epochs == 0.
    n_updates = max(n_updates, 1)
    avg_kl = total_kl / n_updates

    # Adapt the KL coefficient towards the target KL.
    if avg_kl < self.kl_target / 1.5:
        # KL too small -> reduce penalty
        self.kl_coef /= self.kl_coef_adapt
    elif avg_kl > self.kl_target * 1.5:
        # KL too large -> increase penalty
        self.kl_coef *= self.kl_coef_adapt

    # Keep the coefficient in a reasonable range, as a plain Python float so
    # checkpoints and stats do not carry NumPy scalars.
    self.kl_coef = float(np.clip(self.kl_coef, 1e-10, 10.0))

    # Explained variance of the value head after the update.
    with torch.no_grad():
        _, pred_values = self.policy(states)
        pred_values = pred_values.reshape(-1)
        explained_variance = 1 - torch.var(returns - pred_values) / (torch.var(returns) + 1e-8)
        explained_variance = explained_variance.cpu().item()

    stats = {
        'policy_loss': total_policy_loss / n_updates,
        'value_loss': total_value_loss / n_updates,
        'entropy': total_entropy / n_updates,
        'kl_divergence': avg_kl,
        'kl_coef': self.kl_coef,
        'explained_variance': explained_variance
    }

    # The updated policy becomes the reference for the next rollout.
    self.update_old_policy()
    self.reset_buffer()

    return stats
327
+
328
def print_header(self):
    """Print the column header of the training progress table.

    Emits a blank line, a 150-character ``=`` rule, the right-aligned column
    titles joined by `` | `` and a 150-character ``-`` rule.
    """
    rule_width = 150
    columns = (
        ('Episode', 8),
        ('Avg Reward', 12),
        ('Avg Length', 12),
        ('Policy Loss', 12),
        ('Value Loss', 12),
        ('Entropy', 10),
        ('KL Div', 10),
        ('KL Coef', 10),
        ('Expl Var', 12),
        ('Time', 10),
    )

    print("\n" + "=" * rule_width)
    print(" | ".join(f"{title:>{width}}" for title, width in columns))
    print("-" * rule_width)
334
+
335
def print_row(self, episode, total_episodes, stats=None, elapsed_time=None):
    """Print one row of the training progress table.

    Args:
        episode: Number of episodes completed so far.
        total_episodes: Total number of episodes planned.
        stats: Optional dict with ``policy_loss``, ``value_loss``, ``entropy``,
            ``kl_divergence``, ``kl_coef`` and ``explained_variance``. When
            missing or empty, ``-`` placeholders are printed.
        elapsed_time: Seconds since training started, or ``None`` to print a
            placeholder.
    """
    # Rebuild the rolling windows from the tail of the full history. Appending
    # only the latest episode on each call would average one episode per
    # logging interval instead of the most recent consecutive episodes.
    for window, history_key in ((self.recent_rewards, 'episode_rewards'),
                                (self.recent_lengths, 'episode_lengths')):
        history = self.training_stats[history_key]
        window.clear()
        window.extend(history[-window.maxlen:] if window.maxlen else history)

    avg_reward = np.mean(self.recent_rewards) if self.recent_rewards else 0
    avg_length = np.mean(self.recent_lengths) if self.recent_lengths else 0

    if stats:
        metric_cells = [
            f"{stats['policy_loss']:>12.4f}",
            f"{stats['value_loss']:>12.4f}",
            f"{stats['entropy']:>10.4f}",
            f"{stats['kl_divergence']:>10.6f}",
            f"{stats['kl_coef']:>10.6f}",
            f"{stats['explained_variance']:>12.3f}",
        ]
    else:
        metric_cells = [f"{'-':>{width}}" for width in (12, 12, 10, 10, 10, 12)]

    # The default elapsed_time=None cannot be formatted as a float.
    if elapsed_time is None:
        time_cell = f"{'-':>11}"
    else:
        time_cell = f"{elapsed_time:>10.1f}s"

    cells = [
        f"{episode:>8}/{total_episodes}",
        f"{avg_reward:>12.2f}",
        f"{avg_length:>12.1f}",
        *metric_cells,
        time_cell,
    ]
    print(" | ".join(cells))
371
+
372
def print_summary(self, total_time, num_episodes):
    """Print the end-of-training report.

    Always prints the episode count and wall-clock time. Adds the averages of
    the last 20 episodes when at least 20 were recorded, and the most recent
    update statistics when any update has been logged.
    """
    history = self.training_stats
    rule = "=" * 150
    tail = 20

    print(rule)
    print(f"\n🎯 Training completed! Episodes: {num_episodes}, Total time: {total_time:.1f}s")

    if len(history['episode_rewards']) >= tail:
        final_avg_reward = np.mean(history['episode_rewards'][-tail:])
        final_avg_length = np.mean(history['episode_lengths'][-tail:])
        print(f"📊 Last 20 episodes - Avg Reward: {final_avg_reward:.2f}, Avg Length: {final_avg_length:.1f}")

    if history['policy_loss']:
        # (label, stats key, float format) for each final metric, in print order.
        final_metrics = (
            ("📉 Final Policy Loss", 'policy_loss', ".4f"),
            ("📉 Final Value Loss", 'value_loss', ".4f"),
            ("🎲 Final Entropy", 'entropy', ".4f"),
            ("📏 Final KL Divergence", 'kl_divergence', ".6f"),
            ("⚖️ Final KL Coefficient", 'kl_coef', ".6f"),
            ("📈 Final Explained Variance", 'explained_variance', ".3f"),
        )
        for label, key, spec in final_metrics:
            print(f"{label}: {format(history[key][-1], spec)}")

    print(rule)
391
+
392
def train(self, num_episodes=1000, max_steps_per_episode=500, log_interval=20):
    """Train the PPO-Penalty agent on ``self.env``.

    Args:
        num_episodes: Number of episodes to run.
        max_steps_per_episode: Upper bound on the steps taken in one episode.
        log_interval: A progress row is printed every ``log_interval`` episodes.

    Returns:
        Tuple ``(episode_rewards, episode_lengths)`` taken from
        ``self.training_stats``.
    """

    print("\n" + "🚀" * 75)
    print("PPO-Penalty (Adaptive KL) Training Started - Gymnasium API")
    print("🚀" * 75)

    print(f"\n📋 Hyperparameters:")
    print(f" Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")
    print(f" Gamma: {self.gamma:.2f}, GAE Lambda: {self.gae_lambda:.2f}")
    print(f" KL Target: {self.kl_target:.4f}, KL Coef Init: {self.kl_coef:.3f}")
    print(f" KL Adapt Rate: {self.kl_coef_adapt:.2f}")
    print(f" Update Epochs: {self.update_epochs}")
    print(f" Mini-batch: {self.mini_batch_size}, Horizon: {self.horizon}")
    print(f" Entropy Coef: {self.entropy_coef:.3f}, Value Coef: {self.value_coef:.1f}")
    print(f" Device: {self.device}")

    self.print_header()

    total_steps = 0
    episode = 0
    start_time = time.time()

    while episode < num_episodes:
        # Gymnasium API: reset() returns (state, info)
        state, _ = self.env.reset()
        episode_reward = 0
        episode_step = 0

        while episode_step < max_steps_per_episode:
            action, log_prob, value = self.select_action(state)

            # Gymnasium API: step() returns (next_state, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            # NOTE(review): truncation is folded into `done`, and the stored flag
            # zeroes the bootstrap value in compute_gae - confirm this is intended
            # for time-limit truncations.
            done = terminated or truncated

            self.store_transition(state, action, reward, next_state, done, log_prob, value)

            episode_reward += reward
            episode_step += 1
            total_steps += 1
            state = next_state

            # Try an update when the buffer reached the horizon or the episode ended.
            # update() returns None (and keeps the buffer) while fewer than
            # mini_batch_size transitions are stored.
            if len(self.buffer['rewards']) >= self.horizon or done:
                stats = self.update()
                if stats:
                    self.training_stats['policy_loss'].append(stats['policy_loss'])
                    self.training_stats['value_loss'].append(stats['value_loss'])
                    self.training_stats['entropy'].append(stats['entropy'])
                    self.training_stats['kl_divergence'].append(stats['kl_divergence'])
                    self.training_stats['kl_coef'].append(stats['kl_coef'])
                    self.training_stats['explained_variance'].append(stats['explained_variance'])

            if done:
                break

        # Record episode statistics
        self.training_stats['episode_rewards'].append(episode_reward)
        self.training_stats['episode_lengths'].append(episode_step)

        episode += 1

        # Log periodically
        if episode % log_interval == 0:
            current_time = time.time()
            elapsed = current_time - start_time

            # Show the statistics of the most recent update, if any has happened yet.
            recent_stats = None
            if self.training_stats['policy_loss']:
                recent_stats = {
                    'policy_loss': self.training_stats['policy_loss'][-1],
                    'value_loss': self.training_stats['value_loss'][-1],
                    'entropy': self.training_stats['entropy'][-1],
                    'kl_divergence': self.training_stats['kl_divergence'][-1],
                    'kl_coef': self.training_stats['kl_coef'][-1],
                    'explained_variance': self.training_stats['explained_variance'][-1]
                }

            self.print_row(episode, num_episodes, recent_stats, elapsed)

    total_time = time.time() - start_time
    self.print_summary(total_time, num_episodes)

    return self.training_stats['episode_rewards'], self.training_stats['episode_lengths']
477
+
478
def save(self, path):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint holds the policy, old-policy and optimizer state dicts,
    the current KL coefficient and the accumulated training statistics.
    """
    torch.save({
        'policy_state_dict': self.policy.state_dict(),
        'old_policy_state_dict': self.old_policy.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'kl_coef': self.kl_coef,
        'training_stats': self.training_stats
    }, path)
    print(f"\n💾 Model saved to {path}")
488
+
489
def load(self, path):
    """Restore agent state from a checkpoint written by ``save``.

    Restores the policy, old-policy and optimizer state dicts, the KL
    coefficient and the training statistics.

    Args:
        path: Location of the checkpoint file.
    """
    # map_location lets a checkpoint saved on a GPU machine be restored on a
    # CPU-only host (and vice versa) by remapping tensors to this agent's device.
    # NOTE(review): only load checkpoints from a trusted source; torch.load
    # unpickles the file contents.
    checkpoint = torch.load(path, map_location=self.device)
    self.policy.load_state_dict(checkpoint['policy_state_dict'])
    self.old_policy.load_state_dict(checkpoint['old_policy_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Normalise to a plain float in case an older checkpoint stored a NumPy scalar.
    self.kl_coef = float(checkpoint['kl_coef'])
    self.training_stats = checkpoint['training_stats']
    print(f"\n📂 Model loaded from {path}")
498
+
499
+
500
+ # ============== 3. Evaluation Function ==============
501
def evaluate_agent(agent, env, num_episodes=10, render=False):
    """Roll out ``agent`` in ``env`` with greedy actions and report the results.

    Args:
        agent: Object exposing ``select_action(state, eval_mode=True)``.
        env: Gymnasium-style environment (``reset`` / ``step`` / ``render``).
        num_episodes: Number of evaluation episodes.
        render: If True, call ``env.render()`` and pause briefly on every step.

    Returns:
        Tuple ``(episode_rewards, episode_lengths)`` with one entry per episode.
    """
    banner = "🎯" * 35
    rule = "-" * 60

    print("\n" + banner)
    print("Evaluation Started")
    print(banner)

    print(f"\n{'Episode':^12} | {'Reward':^12} | {'Length':^12} | {'Avg Reward':^14}")
    print(rule)

    episode_rewards, episode_lengths = [], []

    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        episode_step = 0
        done = False

        # Runs until the environment reports termination or truncation.
        while not done:
            if render:
                env.render()
                time.sleep(0.02)

            action, _, _ = agent.select_action(state, eval_mode=True)
            state, reward, terminated, truncated, _ = env.step(action)

            episode_reward += reward
            episode_step += 1
            done = terminated or truncated

        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_step)

        avg_so_far = np.mean(episode_rewards)
        print(f"{episode + 1:^12} | {episode_reward:^12.1f} | {episode_step:^12} | {avg_so_far:^14.2f}")

    reward_array = np.array(episode_rewards)
    print(rule)
    print(f"\n📊 Evaluation Results ({num_episodes} episodes):")
    print(f" Avg Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
    print(f" Avg Length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")
    print(f" Max Reward: {np.max(episode_rewards):.2f}")
    print(f" Min Reward: {np.min(episode_rewards):.2f}")
    print(f" Success Rate (>=475): {np.mean(reward_array >= 475) * 100:.1f}%")
    print("=" * 60)

    return episode_rewards, episode_lengths
550
+
551
+
552
+ # ============== 4. Visualization Function ==============
553
def plot_training_results(agent, save_path='ppo_penalty_training_results.png'):
    """Plot the agent's training statistics on a 2x3 grid, save and show the figure.

    Args:
        agent: Object whose ``training_stats`` dict holds the plotted series.
        save_path: File the figure is written to.
    """

    stats = agent.training_stats

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('PPO-Penalty (Adaptive KL) Training Results - CartPole-v1', fontsize=16, y=1.02)

    # Moving-average window (episodes) used for the smoothed curves.
    window = 20

    # 1. Episode Rewards
    ax1 = axes[0, 0]
    rewards = stats['episode_rewards']
    ax1.plot(rewards, alpha=0.3, color='blue', label='Raw Reward')

    if len(rewards) >= window:
        smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
        # Shift the x-axis so each average is drawn at the last episode of its window.
        ax1.plot(range(window - 1, len(smoothed) + window - 1), smoothed,
                 'r-', linewidth=2, label=f'{window}-Episode MA')

    ax1.axhline(y=500, color='green', linestyle='--', alpha=0.7, label='Target (500)')
    ax1.set_xlabel('Episode')
    ax1.set_ylabel('Total Reward')
    ax1.set_title('Training Rewards')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Episode Lengths
    ax2 = axes[0, 1]
    lengths = stats['episode_lengths']
    ax2.plot(lengths, alpha=0.3, color='orange', label='Episode Length')

    if len(lengths) >= window:
        smoothed_length = np.convolve(lengths, np.ones(window) / window, mode='valid')
        ax2.plot(range(window - 1, len(smoothed_length) + window - 1), smoothed_length,
                 'r-', linewidth=2)

    ax2.set_xlabel('Episode')
    ax2.set_ylabel('Episode Length')
    ax2.set_title('Episode Lengths')
    ax2.grid(True, alpha=0.3)

    # 3. Policy Loss & KL Divergence (KL on a secondary y-axis)
    ax3 = axes[0, 2]
    if stats['policy_loss']:
        ax3_twin = ax3.twinx()
        ax3.plot(stats['policy_loss'], color='purple', linewidth=1.5, label='Policy Loss')
        ax3.set_xlabel('Update Step')
        ax3.set_ylabel('Policy Loss', color='purple')
        ax3.tick_params(axis='y', labelcolor='purple')

        if stats['kl_divergence']:
            ax3_twin.plot(stats['kl_divergence'], color='orange', linewidth=1.5, label='KL Div')
            ax3_twin.set_ylabel('KL Divergence', color='orange')
            ax3_twin.tick_params(axis='y', labelcolor='orange')
        ax3.grid(True, alpha=0.3)

    # 4. Value Loss
    ax4 = axes[1, 0]
    if stats['value_loss']:
        ax4.plot(stats['value_loss'], color='brown', linewidth=1.5)
        ax4.set_xlabel('Update Step')
        ax4.set_ylabel('Value Loss')
        ax4.set_title('Value Loss')
        ax4.grid(True, alpha=0.3)

    # 5. Policy Entropy
    ax5 = axes[1, 1]
    if stats['entropy']:
        ax5.plot(stats['entropy'], color='green', linewidth=1.5)
        ax5.set_xlabel('Update Step')
        ax5.set_ylabel('Policy Entropy')
        ax5.set_title('Policy Entropy (Exploration)')
        ax5.grid(True, alpha=0.3)

    # 6. KL Coefficient & Explained Variance (variance on a secondary y-axis)
    ax6 = axes[1, 2]
    if stats['kl_coef']:
        ax6_twin = ax6.twinx()
        ax6.plot(stats['kl_coef'], color='red', linewidth=1.5, label='KL Coef')
        ax6.set_xlabel('Update Step')
        ax6.set_ylabel('KL Coefficient', color='red')
        ax6.tick_params(axis='y', labelcolor='red')

        if stats['explained_variance']:
            ax6_twin.plot(stats['explained_variance'], color='blue', linewidth=1.5, label='Expl Var')
            ax6_twin.set_ylabel('Explained Variance', color='blue')
            ax6_twin.tick_params(axis='y', labelcolor='blue')
        ax6.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\n📸 Training results saved to {save_path}")
647
+
648
+
649
+ # ============== 5. Main Function ==============
650
def main():
    """Train, save, plot, evaluate and demo a PPO-Penalty agent on CartPole-v1."""

    # Create environment (Gymnasium)
    env = gym.make('CartPole-v1')

    # PPO-Penalty hyperparameters
    config = {
        'learning_rate': 3e-4,
        'gamma': 0.99,
        'gae_lambda': 0.95,
        'kl_target': 0.01,  # Target KL divergence per update
        'kl_coef_init': 1.0,  # Initial KL penalty coefficient
        'kl_coef_adapt': 1.5,  # Adaptation rate
        'entropy_coef': 0.01,
        'value_coef': 0.5,
        'max_grad_norm': 0.5,
        'update_epochs': 10,  # More epochs for PPO-Penalty
        'mini_batch_size': 64,
        'horizon': 2048,
        'hidden_dim': 64
    }

    print("\n" + "=" * 100)
    print("PPO-Penalty (Adaptive KL) for CartPole-v1 (Gymnasium)")
    print("=" * 100)
    print("\n📋 Hyperparameters:")
    for key, value in config.items():
        print(f" {key:20}: {value}")

    # Create PPO-Penalty agent
    agent = PPOPenaltyAgent(env, **config)

    try:
        # Train
        rewards, lengths = agent.train(num_episodes=500, log_interval=20)

        # Save model
        agent.save('ppo_penalty_cartpole.pth')

        # Plot results
        plot_training_results(agent)

        # Evaluate without rendering
        print("\n")
        eval_env = gym.make('CartPole-v1')
        evaluate_agent(agent, eval_env, num_episodes=20, render=False)

        # Demo (with rendering)
        print("\n")
        demo_env = gym.make('CartPole-v1', render_mode='human')
        evaluate_agent(agent, demo_env, num_episodes=3, render=True)

    except KeyboardInterrupt:
        # Ctrl-C: keep whatever has been learned so far.
        print("\n\n⚠️ Training interrupted, saving model...")
        agent.save('ppo_penalty_cartpole_interrupted.pth')

    finally:
        env.close()
        # eval_env / demo_env only exist if execution got far enough to create them.
        if 'eval_env' in locals():
            eval_env.close()
        if 'demo_env' in locals():
            demo_env.close()
713
+
714
+
715
+ # ============== 6. Compare PPO-Clip vs PPO-Penalty ==============
716
def compare_ppo_variants():
    """Print a static, conceptual comparison table of PPO-Clip and PPO-Penalty.

    No agents are trained or run here; the table is a fixed text block.
    """

    print("\n" + "🔬" * 50)
    print("PPO-Clip vs PPO-Penalty Comparison")
    print("🔬" * 50)

    # This would require implementing both agents and running experiments
    # For brevity, here's the conceptual comparison:

    comparison = """
📊 **PPO-Clip vs PPO-Penalty Comparison**

============================================================
Feature | PPO-Clip | PPO-Penalty
============================================================
Constraint Type | Hard clipping | Soft KL penalty
------------------------------------------------------------
Update Limit | r ∈ [1-ε, 1+ε] | KL(π||π_old) < target
------------------------------------------------------------
Adaptation | Fixed ε | Adaptive KL coef
------------------------------------------------------------
Implementation | Simple | More complex
------------------------------------------------------------
Compute Cost | Low | Higher (KL calc)
------------------------------------------------------------
Stability | Very stable | Very stable
------------------------------------------------------------
Sample Efficiency | Good | Good
------------------------------------------------------------
Hyperparameter | ε=0.2 (robust) | kl_target=0.01
------------------------------------------------------------
TRPO Relation | Approximation | Direct descendant
------------------------------------------------------------

**When to use PPO-Penalty:**
• When you need precise KL control
• When you're comfortable tuning kl_target
• When you want to stay closer to TRPO theory

**When to use PPO-Clip:**
• Default choice for most problems
• Simpler, fewer hyperparameters
• More widely adopted in practice
"""

    print(comparison)
763
+
764
+
765
# Script entry point: run the train / save / plot / evaluate pipeline.
if __name__ == "__main__":
    main()
    # compare_ppo_variants()  # Uncomment to see comparison
examples/tutorials/rl/cart_pole/step_2_reinforce.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 策略梯度法
5
+
6
+
7
+ 如果在相同的动作序列下,环境会输出相同的状态和奖励,此方法会彻底失效。
8
+
9
+ 推车立杆任务比较简单,每次只有2个动作,它很容易就将每一步的最优选择学会了,但对于复杂任务,REINFORCE 可能要困难得多得多。
10
+
11
+ 其本质是搜集多局游戏数据,进行奖励最大化,再依赖环境的随机性进一步迭代。其类似于动态规划算法。 并不保证能找到全局最优解。
12
+
13
+
14
+ """
15
+ import gymnasium as gym
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.optim as optim
19
+ from torch.distributions import Categorical
20
+ import numpy as np
21
+ import matplotlib.pyplot as plt
22
+ from collections import deque
23
+ import warnings
24
+
25
+ warnings.filterwarnings('ignore')
26
+
27
+
28
+ # ==================== 策略网络 ====================
29
class PolicyNetwork(nn.Module):
    """Policy network mapping a state to a probability distribution over actions.

    The final layer is a softmax, so the output holds one probability per action.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()

        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),  # probabilities over actions
        ]
        self.network = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights (gain sqrt(2)) and zero biases for linear layers."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
        nn.init.constant_(module.bias, 0)

    def forward(self, state):
        """Return the action probabilities for ``state``.

        Input:  state        [batch_size, state_dim]
        Output: action_probs [batch_size, action_dim]
        """
        return self.network(state)

    def get_action(self, state):
        """Sample an action for a single observation.

        Returns:
            ``(action, log_prob)``: the action as a Python int and its
            log-probability as a tensor of shape ``[1]`` that stays attached
            to the autograd graph.
        """
        batched = torch.FloatTensor(state).unsqueeze(0)  # add the batch axis
        dist = Categorical(self.forward(batched))
        action = dist.sample()
        return action.item(), dist.log_prob(action)
75
+
76
+
77
+ # ==================== REINFORCE智能体 ====================
78
class ReinforceAgent:
    """REINFORCE: Monte-Carlo policy gradient.

    The policy is updated from complete trajectories, weighting the
    log-probability of each taken action by the normalised discounted return
    that followed it.
    """

    def __init__(self,
                 state_dim,
                 hidden_dim=128,
                 action_dim=2,
                 lr=1e-3,
                 gamma=0.99):

        self.gamma = gamma
        self.policy = PolicyNetwork(state_dim, hidden_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Trajectory storage for the episode currently being collected.
        self.log_probs = []  # log-probability of each taken action
        self.rewards = []  # reward received at each time step

    def select_action(self, state):
        """Sample an action and remember its log-probability for the next update."""
        action, log_prob = self.policy.get_action(state)
        self.log_probs.append(log_prob)
        return action

    def store_reward(self, reward):
        """Record the reward of the latest step."""
        self.rewards.append(reward)

    def update(self):
        """Apply one REINFORCE gradient step from the stored trajectory.

        Gradient estimate: grad J = E[ sum_t grad log pi(a_t|s_t) * G_t ],
        with G_t = sum_{k>=t} gamma^(k-t) * r_k the discounted return from t.

        Returns:
            The scalar policy loss as a Python float, or ``0.0`` when there
            is no stored trajectory to learn from. The trajectory storage is
            cleared in both cases.
        """
        if not self.rewards or not self.log_probs:
            # Nothing to learn from; stacking an empty list would raise.
            self.log_probs = []
            self.rewards = []
            return 0.0

        # Discounted returns G_t, accumulated backwards in O(n).
        returns = []
        G = 0
        for r in reversed(self.rewards):
            G = r + self.gamma * G
            returns.append(G)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32)

        # Normalise returns to reduce variance. The std of a single element is
        # NaN and would poison the weights, so one-step episodes are left as is.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        # Minus sign: the optimizer minimises, while J is to be maximised.
        policy_loss = torch.stack(
            [-log_prob * G for log_prob, G in zip(self.log_probs, returns)]
        ).sum()

        self.optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
        self.optimizer.step()

        # Clear the trajectory.
        self.log_probs = []
        self.rewards = []

        return policy_loss.item()

    def save(self, path):
        """Write the policy's ``state_dict`` to ``path``."""
        torch.save(self.policy.state_dict(), path)

    def load(self, path):
        """Load a policy ``state_dict`` written by ``save``."""
        # Remap to CPU so checkpoints saved from a GPU still load on CPU-only
        # hosts; load_state_dict copies into whatever device the policy is on.
        self.policy.load_state_dict(torch.load(path, map_location="cpu"))
155
+
156
+
157
+ # ==================== 训练函数 ====================
158
def train_reinforce(env_name='CartPole-v1',
                    hidden_dim=128,
                    lr=1e-3,
                    gamma=0.99,
                    max_episodes=1000,
                    log_interval=20):
    """Train a REINFORCE agent, updating the policy once per finished episode.

    Training stops early once the mean reward of the last 100 episodes
    reaches 475.

    Returns:
        Tuple ``(agent, episode_rewards, episode_losses)``.
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = ReinforceAgent(
        state_dim=state_dim,
        hidden_dim=hidden_dim,
        action_dim=action_dim,
        lr=lr,
        gamma=gamma
    )

    # Training history; the deque keeps the rolling 100-episode window.
    episode_rewards, episode_losses = [], []
    moving_avg_rewards = deque(maxlen=100)

    print(f"开始训练 REINFORCE on {env_name}")
    print(f"状态维度: {state_dim}, 动作维度: {action_dim}")
    print(f"学习率: {lr}, 折扣因子: {gamma}")
    print("-" * 50)

    for episode in range(1, max_episodes + 1):
        state, _ = env.reset()
        episode_reward = 0

        # Collect one complete trajectory.
        while True:
            action = agent.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            agent.store_reward(reward)
            episode_reward += reward
            if terminated or truncated:
                break

        # Update the policy from the finished trajectory.
        episode_loss = agent.update()

        episode_rewards.append(episode_reward)
        episode_losses.append(episode_loss)
        moving_avg_rewards.append(episode_reward)

        if episode % log_interval == 0:
            avg_reward = np.mean(moving_avg_rewards)
            print(f"Episode {episode:5d} | "
                  f"Reward: {episode_reward:6.2f} | "
                  f"Avg Reward: {avg_reward:6.2f} | "
                  f"Loss: {episode_loss:8.4f}")

        # Early stop: full 100-episode window with mean reward >= 475.
        window_full = len(moving_avg_rewards) == 100
        if window_full and np.mean(moving_avg_rewards) >= 475:
            print(f"\n🎉 在第 {episode} 回合解决问题!平均奖励: {np.mean(moving_avg_rewards):.2f}")
            break

    env.close()
    return agent, episode_rewards, episode_losses
237
+
238
+
239
+ # ==================== 可视化函数 ====================
240
def plot_training(rewards, losses, save_path=None):
    """Plot training curves on a 2x2 grid and show the figure.

    Panels: raw episode rewards, 20-episode moving average with a shaded band,
    policy loss, and a histogram of the last 100 episode rewards.

    Args:
        rewards: Total reward per episode.
        losses: Policy loss per episode.
        save_path: If given, the figure is also written to this file.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Raw reward curve
    axes[0, 0].plot(rewards, alpha=0.6, color='blue', linewidth=0.8)
    axes[0, 0].set_xlabel('Episode')
    axes[0, 0].set_ylabel('Total Reward')
    axes[0, 0].set_title('Training Rewards')
    axes[0, 0].grid(True, alpha=0.3)

    # 2. Moving-average reward
    window = 20
    moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
    axes[0, 1].plot(moving_avg, color='red', linewidth=2)
    # The band width is one standard deviation of the first len(moving_avg)
    # rewards, i.e. a single global value rather than a rolling one.
    axes[0, 1].fill_between(range(len(moving_avg)),
                            moving_avg - np.std(rewards[:len(moving_avg)]),
                            moving_avg + np.std(rewards[:len(moving_avg)]),
                            alpha=0.2, color='red')
    axes[0, 1].set_xlabel('Episode')
    axes[0, 1].set_ylabel(f'Moving Avg Reward (window={window})')
    axes[0, 1].set_title('Smoothed Training Curve')
    axes[0, 1].grid(True, alpha=0.3)

    # 3. Loss curve
    axes[1, 0].plot(losses, color='green', alpha=0.6, linewidth=0.8)
    axes[1, 0].set_xlabel('Episode')
    axes[1, 0].set_ylabel('Policy Loss')
    axes[1, 0].set_title('Training Loss')
    axes[1, 0].grid(True, alpha=0.3)

    # 4. Reward distribution histogram (last 100 episodes)
    axes[1, 1].hist(rewards[-100:], bins=20, color='purple', alpha=0.7, edgecolor='black')
    axes[1, 1].set_xlabel('Total Reward')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_title('Reward Distribution (Last 100 Episodes)')
    axes[1, 1].axvline(x=np.mean(rewards[-100:]), color='red', linestyle='--',
                       label=f'Mean: {np.mean(rewards[-100:]):.1f}')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
288
+
289
+
290
+ # ==================== 演示智能体 ====================
291
def demo_agent(agent, env_name='CartPole-v1', episodes=5):
    """Run the trained agent in a rendered environment and print each episode's reward.

    Actions are sampled from ``agent.policy`` directly, so nothing is added to
    the agent's trajectory storage.
    """
    env = gym.make(env_name, render_mode='human')

    for episode_number in range(1, episodes + 1):
        state, _ = env.reset()
        total_reward = 0

        while True:
            action, _ = agent.policy.get_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        print(f"Demo Episode {episode_number}: Reward = {total_reward}")

    env.close()
312
+
313
+
314
+ # ==================== 超参数调优 ====================
315
def hyperparameter_sweep():
    """Grid-search learning rate and hidden size with short training runs.

    Each combination is trained for 300 episodes and scored by the mean reward
    of its last 50 episodes.

    Returns:
        Dict mapping ``(lr, hidden)`` to that score.
    """
    learning_rates = [1e-4, 3e-4, 1e-3, 3e-3]
    hidden_sizes = [64, 128, 256]
    # Same visiting order as nested loops: learning rate outer, hidden size inner.
    grid = [(lr, hidden) for lr in learning_rates for hidden in hidden_sizes]

    results = {}
    for lr, hidden in grid:
        print(f"\n测试 lr={lr}, hidden={hidden}")
        history = train_reinforce(
            lr=lr,
            hidden_dim=hidden,
            max_episodes=300,
            log_interval=50
        )[1]

        score = np.mean(history[-50:])
        results[(lr, hidden)] = score
        print(f"平均奖励: {score:.2f}")

    # Report the best-scoring combination.
    best = max(results, key=results.get)
    best_lr, best_hidden = best
    print(f"\n最佳参数: lr={best_lr}, hidden={best_hidden}")
    print(f"最佳平均奖励: {results[best]:.2f}")

    return results
344
+
345
+
346
+ # ==================== 主程序 ====================
347
# Script entry point: train, plot, report, save and demo a REINFORCE agent.
if __name__ == "__main__":
    # Seed torch and NumPy for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Training configuration
    CONFIG = {
        'env_name': 'CartPole-v1',
        'hidden_dim': 128,
        'lr': 1e-3,
        'gamma': 0.99,
        'max_episodes': 800,
        'log_interval': 20
    }

    # Train the agent
    agent, rewards, losses = train_reinforce(**CONFIG)

    # Plot the training curves
    plot_training(rewards, losses, save_path='reinforce_training.png')

    # Print the final results
    print("\n" + "=" * 50)
    print("训练完成!")
    print(f"最高奖励: {max(rewards):.2f}")
    print(f"平均奖励(最后100局): {np.mean(rewards[-100:]):.2f}")
    print(f"标准差(最后100局): {np.std(rewards[-100:]):.2f}")
    print("=" * 50)

    # Save the model
    agent.save('reinforce_cartpole.pth')
    print("模型已保存到 reinforce_cartpole.pth")

    # Demo
    print("\n开始演示...")
    demo_agent(agent, episodes=3)
examples/tutorials/rl/cart_pole/step_2_reinforce_with_baseline.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 带基线的方法(REINFORCE with Baseline),相当于是先预测一个未来预期奖励的估计,
5
+ 然后如果实际的奖励大于这个值,则模型得到正向反馈,动作的概率会被加大,
6
+ 如果小于这个值,则动作的概率会减小。
7
+ 这相比于原始的方法(REINFORCE)永远只是增大动作的概率使训练变得更稳定。
8
+
9
+ 如果按照训练的轮次/游戏局数来比较,带基线的方法确实收敛更快。
10
+ 但额外的价值网络会增加额外的计算开销。
11
+
12
+ 原始REINFORCE 方法,当确定的动作序列导致确定的状态序列时,会失效。
13
+ 但带基线REINFORCE 仍然有效。
14
+
15
+ """
16
+ import argparse
17
+ from collections import deque
18
+
19
+ import gymnasium as gym
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.optim as optim
25
+ from torch.distributions import Categorical
26
+
27
+
28
class PolicyNetwork(nn.Module):
    """Policy network: maps a state to a probability distribution over actions."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # Two hidden layers followed by an output layer with one logit per action.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, state):
        """Return action probabilities (softmax over the last dimension)."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        logits = self.fc3(hidden)
        return self.softmax(logits)
43
+
44
+
45
class ValueNetwork(nn.Module):
    """Value network used as the baseline: maps a state to one scalar estimate."""

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return the value estimate; the last dimension has size 1."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)
58
+
59
+
60
class REINFORCEwithBaseline:
    """Policy-gradient (REINFORCE) agent with a learned value-network baseline."""

    def __init__(self,
                 env,
                 policy_lr=1e-3,
                 value_lr=1e-3,
                 gamma=0.99,
                 hidden_dim=128,
                 render=False):
        """Build the policy/value networks and their optimizers.

        Args:
            env: environment exposing ``observation_space.shape`` and
                ``action_space.n`` plus ``reset`` / ``step`` / ``render``.
            policy_lr: Adam learning rate for the policy network.
            value_lr: Adam learning rate for the value network.
            gamma: discount factor.
            hidden_dim: width of the hidden layers in both networks.
            render: if true, ``env.render()`` is called on every training step.
        """
        self.env = env
        self.gamma = gamma
        self.render = render

        # State and action dimensions come from the environment spaces.
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        # Policy network (actor) and value network (baseline).
        self.policy_net = PolicyNetwork(self.state_dim, hidden_dim, self.action_dim)
        self.value_net = ValueNetwork(self.state_dim, hidden_dim)

        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=value_lr)

        # Per-episode trajectory storage.
        self.reset_memory()

        # Training history.
        self.training_stats = {'episode_rewards': [], 'baseline_loss': []}

    def reset_memory(self):
        """Clear the stored trajectory of the current episode."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.log_probs = []

    def select_action(self, state):
        """Sample an action from the current policy and record it.

        Returns:
            int: the sampled action index.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = self.policy_net(state)
        m = Categorical(probs)
        action = m.sample()
        log_prob = m.log_prob(action)

        # Keep what update() needs later.
        self.states.append(state)
        self.actions.append(action)
        self.log_probs.append(log_prob)

        return action.item()

    def compute_returns(self):
        """Compute normalized discounted returns for the stored rewards.

        Returns:
            torch.Tensor: 1-D float tensor with one entry per stored reward.
            Returns are centered and, when there are at least two steps,
            divided by their standard deviation.
        """
        returns = []
        R = 0
        for r in reversed(self.rewards):
            R = r + self.gamma * R
            returns.append(R)
        # Built back-to-front; one reverse avoids repeated front insertion.
        returns.reverse()
        returns = torch.FloatTensor(returns)

        centered = returns - returns.mean()
        # The sample std of a single element is NaN, which would poison both
        # losses; for a one-step episode only center the return.
        if returns.numel() < 2:
            return centered
        return centered / (returns.std() + 1e-9)

    def update(self):
        """Run one gradient step on the value network and one on the policy network."""
        if len(self.rewards) == 0:
            return

        returns = self.compute_returns()
        states = torch.cat(self.states)

        # 1. Fit the value network (baseline) to the normalized returns.
        # squeeze(-1) keeps the step dimension even for a one-step episode.
        value_pred = self.value_net(states).squeeze(-1)
        value_loss = nn.MSELoss()(value_pred, returns)

        self.value_optimizer.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)
        self.value_optimizer.step()

        # 2. Update the policy using the refreshed baseline as the advantage reference.
        with torch.no_grad():
            baselines = self.value_net(states).squeeze(-1)

        log_probs = torch.cat(self.log_probs)
        advantages = returns - baselines
        policy_loss = -(log_probs * advantages).sum()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
        self.policy_optimizer.step()

        # Record the value loss, then drop the episode's trajectory.
        self.training_stats['baseline_loss'].append(value_loss.item())
        self.reset_memory()

    def train(self, num_episodes, max_steps_per_episode=500):
        """Train for ``num_episodes`` episodes, updating once per episode.

        The best model (by 100-episode moving average, after episode 100) is
        written to ``best_model.pth``.

        Returns:
            list: total reward of every episode.
        """
        episode_rewards = []
        best_avg_reward = -np.inf
        reward_window = deque(maxlen=100)

        for episode in range(num_episodes):
            state, _ = self.env.reset()
            episode_reward = 0

            for step in range(max_steps_per_episode):
                if self.render:
                    self.env.render()

                action = self.select_action(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                self.rewards.append(reward)
                episode_reward += reward

                if done:
                    break

                state = next_state

            # One Monte-Carlo update per finished episode.
            self.update()

            episode_rewards.append(episode_reward)
            reward_window.append(episode_reward)
            avg_reward = np.mean(reward_window)

            self.training_stats['episode_rewards'].append(episode_reward)

            if (episode + 1) % 10 == 0:
                avg_baseline_loss = np.mean(self.training_stats['baseline_loss'][-10:])
                print(f'Episode {episode + 1}/{num_episodes}, '
                      f'Reward: {episode_reward:.1f}, '
                      f'Avg Reward: {avg_reward:.1f}, '
                      f'Baseline Loss: {avg_baseline_loss:.4f}')

            # Checkpoint whenever the moving average improves.
            if avg_reward > best_avg_reward and episode > 100:
                best_avg_reward = avg_reward
                torch.save({
                    'policy_state_dict': self.policy_net.state_dict(),
                    'value_state_dict': self.value_net.state_dict(),
                    'avg_reward': avg_reward,
                    'episode': episode
                }, 'best_model.pth')

        return episode_rewards
223
+
224
+
225
def plot_training_results(rewards, baseline_loss, window=100):
    """Plot episode rewards and baseline loss, save the figure, then show it.

    Args:
        rewards: per-episode total rewards.
        baseline_loss: value-network loss per update.
        window: moving-average window for the smoothed reward curve; the
            smoothed curve is drawn only when at least ``window`` rewards exist.
    """
    fig, (reward_ax, loss_ax) = plt.subplots(2, 1, figsize=(10, 8))

    # Top panel: raw rewards plus an optional moving average.
    reward_ax.plot(rewards, alpha=0.3, color='blue', label='Episode Reward')
    if len(rewards) >= window:
        kernel = np.ones(window) / window
        smoothed = np.convolve(rewards, kernel, mode='valid')
        reward_ax.plot(range(window - 1, len(rewards)), smoothed,
                       color='red', linewidth=2, label=f'Moving Avg (window={window})')
    reward_ax.set_xlabel('Episode')
    reward_ax.set_ylabel('Total Reward')
    reward_ax.set_title('REINFORCE with Baseline - Training Rewards')
    reward_ax.legend()
    reward_ax.grid(True, alpha=0.3)

    # Bottom panel: value-network (baseline) loss.
    loss_ax.plot(baseline_loss, color='green', alpha=0.6, label='Baseline Loss')
    loss_ax.set_xlabel('Update Step')
    loss_ax.set_ylabel('MSE Loss')
    loss_ax.set_title('Value Network (Baseline) Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_results.png', dpi=100)
    plt.show()
255
+
256
+
257
def test_agent(env, policy_net, num_episodes=10, render=True):
    """Evaluate a policy greedily (argmax action) and print per-episode rewards.

    Args:
        env: environment with ``reset`` / ``step`` / ``render``.
        policy_net: callable mapping a batched state tensor to action probabilities.
        num_episodes: number of evaluation episodes.
        render: if true, ``env.render()`` is called before every step.

    Returns:
        list: total reward of each evaluation episode.
    """
    episode_rewards = []

    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0

        while True:
            if render:
                env.render()

            batched_state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                probs = policy_net(batched_state)
            # Greedy evaluation: take the most probable action.
            action = torch.argmax(probs).item()

            state, reward, terminated, truncated, _ = env.step(action)
            episode_reward += reward
            if terminated or truncated:
                break

        episode_rewards.append(episode_reward)
        print(f'Test Episode {episode + 1}: Reward = {episode_reward}')

    print(f'Average Test Reward: {np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}')
    return episode_rewards
285
+
286
+
287
def main():
    """Parse CLI options, train REINFORCE with baseline on CartPole, plot, then test."""
    parser = argparse.ArgumentParser(description='REINFORCE with Baseline for CartPole')
    parser.add_argument('--episodes', type=int, default=1000, help='Number of training episodes')
    parser.add_argument('--policy_lr', type=float, default=1e-3, help='Learning rate for policy network')
    parser.add_argument('--value_lr', type=float, default=1e-3, help='Learning rate for value network')
    parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden layer dimension')
    parser.add_argument('--render_train', action='store_true', help='Render during training')
    parser.add_argument('--render_test', action='store_true', help='Render during testing')
    args = parser.parse_args()

    # gymnasium renders only when a render_mode is chosen at construction time;
    # without it the --render_* flags would have no visible effect.
    env = gym.make('CartPole-v1', render_mode='human' if args.render_train else None)
    test_env = None
    try:
        agent = REINFORCEwithBaseline(
            env=env,
            policy_lr=args.policy_lr,
            value_lr=args.value_lr,
            gamma=args.gamma,
            hidden_dim=args.hidden_dim,
            render=args.render_train
        )

        # Train.
        print("开始训练 REINFORCE with Baseline...")
        agent.train(num_episodes=args.episodes)

        # Plot the training curves.
        plot_training_results(
            agent.training_stats['episode_rewards'],
            agent.training_stats['baseline_loss']
        )

        # Evaluate the trained policy in a fresh environment.
        print("\n测试训练好的模型...")
        test_env = gym.make('CartPole-v1', render_mode='human' if args.render_test else None)
        test_agent(test_env, agent.policy_net, num_episodes=10, render=args.render_test)
    finally:
        # Always release the environments, even if training or testing fails.
        env.close()
        if test_env is not None:
            test_env.close()


if __name__ == "__main__":
    main()
examples/tutorials/rl/cart_pole/step_2_rl_dqn.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 深度强化学习 + 基于值 + 模型无关 + 异策略 + 离线学习
5
+
6
+ 强化学习算法
7
+
8
+ ├── 🔵 基于值的方法 (Value-Based)
9
+ │ ├── 传统: Q-learning, SARSA
10
+ │ └── 🎯 DQN ← 在这里!
11
+ │ ├── DQN
12
+ │ ├── Double DQN
13
+ │ ├── Dueling DQN
14
+ │ └── Rainbow
15
+
16
+ ├── 🔴 基于策略的方法 (Policy-Based)
17
+ │ ├── REINFORCE
18
+ │ ├── PPO
19
+ │ └── TRPO
20
+
21
+ └── 🟣 演员-评论家方法 (Actor-Critic)
22
+ ├── A2C/A3C
23
+ ├── SAC
24
+ └── TD3
25
+
26
+ 贝尔曼方程
27
+
28
+ 贝尔曼的洞察(1957):
29
+ "最优策略有这样的性质:无论初始状态和初始决策如何,其余的决策必须构成一个以第一个决策产生的状态为初始状态的最优策略。"
30
+ 翻译成人话:
31
+ 如果A→B→C是最优路径,那么B→C也必须是最优路径!
32
+
33
+ 备注:所以 DQN 是基于贝尔曼假设的,当环境不符合这个假设时,该方法不成立。
34
+
35
+ """
36
+ from collections import deque
37
+ import random
38
+
39
+ import gymnasium as gym
40
+ import numpy as np
41
+ import matplotlib.pyplot as plt
42
+ import torch
43
+ import torch.nn as nn
44
+ import torch.optim as optim
45
+
46
+
47
+ # 神经网络定义
48
class DQN(nn.Module):
    """Three-layer MLP that maps a state to one Q-value per action."""

    def __init__(self, state_size: int, action_size: int):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return Q-values; the last dimension has size ``action_size``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)
60
+
61
+
62
+ # 经验回放缓冲区
63
class ReplayBuffer:
    """Fixed-capacity experience replay buffer; the oldest entries are evicted first."""

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition.

        :param state: state before the action (the original note lists it as
            [position, velocity, angle, angular velocity])
        :param action: action taken
        :param reward: float reward received
        :param next_state: state after the action
        :param done: episode-finished flag stored with the transition
        :return: None
        """
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions without replacement.

        :return: tuple of five numpy arrays
            (states, actions, rewards, next_states, dones)
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one column per field.
        return tuple(np.array(column) for column in zip(*picked))

    def __len__(self):
        return len(self.buffer)
92
+
93
+
94
+ # DQN智能体
95
class DQNAgent:
    """DQN agent: epsilon-greedy policy network, target network and replay buffer."""

    def __init__(self, state_size: int, action_size: int):
        self.state_size = state_size
        self.action_size = action_size

        # Online network and a target network that starts as an exact copy.
        self.policy_net = DQN(state_size, action_size)
        self.target_net = DQN(state_size, action_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.001)

        # Hyper-parameters.
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.batch_size = 64
        self.buffer = ReplayBuffer(10000)

        # How often (in episodes) the caller syncs the target network.
        self.target_update = 10

    def select_action(self, state, training=True):
        """Pick an action: random with probability epsilon while training, else greedy."""
        explore = training and np.random.random() < self.epsilon
        if explore:
            return random.randrange(self.action_size)

        batched_state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(batched_state)
        return q_values.argmax().item()

    def train_step(self):
        """Run one optimization step on a sampled batch; no-op until the buffer holds a full batch."""
        if len(self.buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Convert to tensors; scalar fields become column vectors [batch_size, 1].
        states = torch.FloatTensor(states)
        next_states = torch.FloatTensor(next_states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        dones = torch.FloatTensor(dones).unsqueeze(1)

        # Q(s, a) of the actions actually taken: [batch_size, 1].
        all_q = self.policy_net(states)
        current_q = torch.gather(all_q, dim=1, index=actions)

        # Bellman target built from the target network, without gradients.
        with torch.no_grad():
            next_all_q = self.target_net(next_states)
            best_next_q, _ = torch.max(next_all_q, 1)
            best_next_q = torch.unsqueeze(best_next_q, 1)
            target_q = rewards + (1 - dones) * self.gamma * best_next_q

        # current_q: predicted discounted future reward for the chosen action.
        # target_q: observed reward plus the discounted estimate for the next state.
        loss = nn.MSELoss()(current_q, target_q)

        # A separate target_net is used because a single batch update is noisy;
        # the target is kept fixed between the caller's periodic syncs.
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        # Decay the exploration rate after every training step.
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
177
+
178
+
179
+ # 训练函数
180
def train_dqn():
    """Train a DQN agent on CartPole-v1 for 500 episodes and plot the rewards.

    Returns:
        DQNAgent: the trained agent.
    """
    env = gym.make("CartPole-v1")
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = DQNAgent(state_size, action_size)
    episodes = 500
    rewards = []

    for episode in range(episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            # Choose an action (epsilon-greedy).
            action = agent.select_action(state)

            # Step the environment.
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store only `terminated` as the terminal flag: a time-limit
            # truncation ends the episode but the state is not terminal, so
            # the Bellman target must still bootstrap from next_state.
            agent.buffer.push(state, action, reward, next_state, terminated)

            # One optimization step per environment step.
            agent.train_step()

            state = next_state
            total_reward += reward

        # Periodically sync the target network with the policy network.
        if episode % agent.target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())

        rewards.append(total_reward)

        if (episode + 1) % 50 == 0:
            avg_reward = np.mean(rewards[-50:])
            print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Epsilon: {agent.epsilon:.3f}")

    env.close()

    # Visualization: raw rewards and a trailing moving average.
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(rewards)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Training Rewards')

    plt.subplot(1, 2, 2)
    window = 20
    moving_avg = [np.mean(rewards[max(0, i - window):i + 1]) for i in range(len(rewards))]
    plt.plot(moving_avg)
    plt.xlabel('Episode')
    plt.ylabel('Moving Avg Reward')
    plt.title(f'Moving Average (window={window})')

    plt.tight_layout()
    plt.show()

    return agent
244
+
245
+
246
# Run training only when executed as a script, so that importing this module
# does not start a full training run as a side effect.
if __name__ == "__main__":
    agent = train_dqn()
examples/tutorials/rlhf/gpt2_sst2/step_1_prepare_data.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 或使用命令行
5
+ pip install modelscope
6
+ modelscope download \
7
+ --model 'qgyd2021/Qwen3-8B-sft-deepspeed' \
8
+ --local_dir '/root/autodl-tmp/trained_models/Qwen3-8B-sft-deepspeed'
9
+
10
+ """
11
+ import argparse
12
+ import os
13
+ from pathlib import Path
14
+ import platform
15
+
16
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
17
+
18
+ if platform.system() in ("Windows", "Darwin"):
19
+ from project_settings import project_path, temp_directory
20
+ else:
21
+ project_path = os.path.abspath("../../../")
22
+ project_path = Path(project_path)
23
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
24
+
25
+ from modelscope import snapshot_download
26
+ # from huggingface_hub import snapshot_download
27
+
28
+
29
def get_args():
    """Parse command-line options: the repo to download and the target directory."""
    default_local_dir = (temp_directory / "../trained_models/Qwen/Qwen2.5-0.5B").as_posix()

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", default="Qwen/Qwen2.5-0.5B", type=str)
    parser.add_argument("--local_dir", default=default_local_dir, type=str)
    return parser.parse_args()
39
+
40
+
41
def main():
    """Download the requested model snapshot into ``--local_dir``."""
    args = get_args()

    # Keyword style matches the huggingface_hub call (repo_type / repo_id);
    # the equivalent modelscope-style call would pass model_id=args.repo_id.
    download_kwargs = {
        "repo_type": "model",
        "repo_id": args.repo_id,
        "local_dir": args.local_dir,
    }
    snapshot_download(**download_kwargs)


if __name__ == "__main__":
    main()
examples/tutorials/rlhf/gpt2_sst2/step_2_train_sft_model.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 用sst的句子训练gpt2模型,让其随机生成一些评论。
5
+ """
6
+ import argparse
7
+ import os
8
+ from pathlib import Path
9
+ import platform
10
+
11
+ if platform.system() in ("Windows", "Darwin"):
12
+ from project_settings import project_path, temp_directory
13
+ else:
14
+ project_path = os.path.abspath("../../../")
15
+ project_path = Path(project_path)
16
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
17
+
18
+ from datasets import load_dataset
19
+ import torch
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
21
+
22
+
23
def get_args():
    """Parse command-line options for SFT training of GPT-2 on SST-2 sentences.

    Returns:
        argparse.Namespace: parsed arguments.
    """
    on_desktop = platform.system() in ("Windows", "Darwin")
    # os.cpu_count() may return None; fall back to a single worker then.
    default_workers = None if on_desktop else (os.cpu_count() or 2) // 2

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        # default="openai-community/gpt2",
        default=(project_path / "pretrained_models/openai-community/gpt2").as_posix(),
        type=str
    )
    parser.add_argument(
        "--dataset_path",
        default="stanfordnlp/sst2",
        type=str
    )
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str
    )
    parser.add_argument("--dataset_streaming", default=None, type=str)
    parser.add_argument("--valid_dataset_size", default=1000, type=int)
    parser.add_argument("--shuffle_buffer_size", default=5000, type=int)

    parser.add_argument(
        "--output_model_dir",
        default=(project_path / "trained_models/gpt2-sst2-generation-20260213-2048").as_posix(),
        type=str
    )

    parser.add_argument(
        "--num_workers",
        default=default_workers,
        type=int
    )
    parser.add_argument(
        "--device",
        default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # A device is named by a string such as "cuda" or "cpu"; type=int made
        # every explicit --device value fail to parse.
        type=str
    )
    args = parser.parse_args()
    return args
70
+
71
+
72
class _CausalLMCollator:
    """Pad a batch for causal-LM training and mask only the padded label positions."""

    def __init__(self, tokenizer):
        self._base = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    def __call__(self, features):
        batch = self._base(features)
        # The base collator masks every label equal to pad_token_id. Because
        # pad_token is set to eos_token here, that would also hide the real EOS
        # appended to each sentence, so the model would never learn to stop.
        # Rebuild the labels from the attention mask instead.
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels
        return batch


def main():
    """Fine-tune a causal LM on SST-2 sentences so it generates review-like text."""
    args = get_args()

    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    model = model.to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # GPT-2 has no pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    dataset_dict = load_dataset(
        path=args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        cache_dir=args.dataset_cache_dir,
        # num_proc=args.num_workers if not args.dataset_streaming else None,
        streaming=args.dataset_streaming,
    )
    train_dataset = dataset_dict["train"]
    valid_dataset = dataset_dict["validation"]

    def format_func(example):
        # Append EOS so each training sample carries an explicit end marker.
        sentence = example["sentence"] + tokenizer.eos_token
        tokenized = tokenizer(sentence)
        return {
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
        }

    train_dataset = train_dataset.map(
        format_func,
        batched=False,
        remove_columns=train_dataset.column_names,
    )
    valid_dataset = valid_dataset.map(
        format_func,
        batched=False,
        remove_columns=valid_dataset.column_names,
    )
    # Drop very short samples (5 tokens or fewer, EOS included).
    train_dataset = train_dataset.filter(
        function=lambda x: len(x["input_ids"]) > 5
    )
    valid_dataset = valid_dataset.filter(
        function=lambda x: len(x["input_ids"]) > 5
    )

    data_collator = _CausalLMCollator(tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_model_dir,
        # overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        logging_steps=100,
        learning_rate=5e-5,
        warmup_steps=500,
        weight_decay=0.01,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=args.num_workers or 0,
        remove_unused_columns=False,
        load_best_model_at_end=False,
        # metric_for_best_model="eval_loss",
        # greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
    )

    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()
examples/tutorials/rlhf/gpt2_sst2/step_3_train_reward_model.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import platform
7
+ from typing import Any, Dict, List, Optional, Union, Tuple
8
+
9
+ if platform.system() in ("Windows", "Darwin"):
10
+ from project_settings import project_path, temp_directory
11
+ else:
12
+ project_path = os.path.abspath("../../../")
13
+ project_path = Path(project_path)
14
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
15
+
16
+ from datasets import load_dataset
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers import (AutoModelForCausalLM,
22
+ AutoTokenizer,
23
+ GPT2PreTrainedModel, GPT2Config, GPT2Model,
24
+ DataCollatorWithPadding,
25
+ Trainer, TrainingArguments
26
+ )
27
+
28
+
29
def get_args():
    """Parse command-line options for reward-model training on SST-2."""
    # Defaults that depend on the platform or on project paths.
    default_model = (project_path / "pretrained_models/openai-community/gpt2").as_posix()
    default_output = (project_path / "trained_models/gpt2-sst2-reward-20260213-2122").as_posix()
    default_workers = None if platform.system() in ("Windows", "Darwin") else os.cpu_count() // 2
    default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser()
    # Model and dataset locations. (Hub id alternative: "openai-community/gpt2".)
    parser.add_argument("--model_name", default=default_model, type=str)
    parser.add_argument("--dataset_path", default="stanfordnlp/sst2", type=str)
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )
    parser.add_argument("--dataset_streaming", default=None, type=str)
    parser.add_argument("--valid_dataset_size", default=1000, type=int)
    parser.add_argument("--shuffle_buffer_size", default=5000, type=int)

    # Output and runtime settings.
    parser.add_argument("--output_model_dir", default=default_output, type=str)
    parser.add_argument("--num_workers", default=default_workers, type=int)
    parser.add_argument("--device", default=default_device, type=str)

    return parser.parse_args()
75
+
76
+
77
class RewardHead(nn.Module):
    """Linear head that projects each hidden state to a single reward logit."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        """Initialize weights ~ N(0, 1 / (hidden_size + 1)) and set the bias to zero."""
        weight_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.linear.weight, std=weight_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len, 1]
        return self.linear(hidden_states)
96
+
97
+
98
class GPT2RewardModel(GPT2PreTrainedModel):
    """GPT-2 backbone with a per-token reward head whose output lies in (0, 1)."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.reward_head = RewardHead(config.hidden_size)
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, torch.Tensor]:
        """Return sigmoid rewards of shape [batch_size, seq_len], one per token."""
        backbone_out = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        # Last entry of hidden_states: [batch_size, seq_len, hidden_size].
        final_hidden = backbone_out.hidden_states[-1]
        # Head output [batch_size, seq_len, 1]; drop the trailing unit dimension.
        token_logits = torch.squeeze(self.reward_head(final_hidden), -1)
        return torch.sigmoid(token_logits)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
        """Build the model from a config and copy GPT2Model weights into the backbone.

        Only ``self.transformer`` receives pretrained weights here.
        NOTE(review): the reward head is never loaded by this method, so it
        presumably cannot restore a saved reward checkpoint in full - confirm.
        ``*model_args`` and ``**kwargs`` are accepted but not used.
        """
        config = GPT2Config.from_pretrained(model_name_or_path)
        reward_model = cls(config)
        backbone = GPT2Model.from_pretrained(model_name_or_path)
        reward_model.transformer.load_state_dict(backbone.state_dict(), strict=False)
        return reward_model
132
+
133
+
134
class SST2RewardTrainer(Trainer):
    """Trainer that regresses the reward at the last attended token onto ``score``."""

    def compute_loss(
            self,
            model: nn.Module,
            inputs: dict[str, Union[torch.Tensor, Any]],
            return_outputs: bool = False,
            num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        """Return the MSE between the sequence-level reward and ``inputs["score"]``.

        With ``return_outputs=True`` a ``(loss, outputs)`` pair is returned,
        where ``outputs`` holds the loss and the detached predictions.
        """
        mask = inputs["attention_mask"]
        token_rewards = model(input_ids=inputs["input_ids"], attention_mask=mask)

        # Index of the last attended token per row (mask sum minus one).
        last_index = mask.sum(dim=1) - 1
        rows = torch.arange(token_rewards.size(0), device=token_rewards.device)
        # One reward per sequence: shape [batch_size].
        sequence_reward = token_rewards[rows, last_index]

        loss = F.mse_loss(sequence_reward, inputs["score"].float())
        if not return_outputs:
            return loss
        outputs = {
            "loss": loss,
            "predictions": sequence_reward.detach(),
        }
        return loss, outputs

    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[list[str]] = None,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Evaluate one batch without gradients; return ``(loss, predictions, labels)``."""
        with torch.no_grad():
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return loss, None, None

        return loss, outputs["predictions"], inputs["score"].float()
177
+
178
+
179
def compute_metrics(eval_pred):
    """Compute evaluation metrics from a ``(predictions, labels)`` pair.

    Returns:
        dict: mean / std of the absolute error, mean / min / max of the
        predicted rewards, and the mean of the labels.
    """
    raw_predictions, raw_labels = eval_pred
    rewards = torch.tensor(raw_predictions)
    scores = torch.tensor(raw_labels)

    abs_error = torch.abs(rewards - scores)

    metrics = {}
    metrics["mean_error"] = abs_error.mean().item()
    metrics["std_error"] = abs_error.std().item()
    metrics["reward_mean"] = rewards.mean().item()
    metrics["reward_min"] = rewards.min().item()
    metrics["reward_max"] = rewards.max().item()
    metrics["score_mean"] = scores.mean().item()
    return metrics
195
+
196
+
197
def main():
    """Train the GPT-2 reward model on SST-2, using the label as the target score."""
    args = get_args()

    model = GPT2RewardModel.from_pretrained(args.model_name)
    model = model.to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    dataset_dict = load_dataset(
        path=args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        cache_dir=args.dataset_cache_dir,
        # num_proc=args.num_workers if not args.dataset_streaming else None,
        streaming=args.dataset_streaming,
    )

    def format_func(example):
        # Tokenize "sentence + EOS" and carry the label along as a float score.
        sentence: str = example["sentence"]
        score: float = float(example["label"])
        tokenized = tokenizer(sentence + tokenizer.eos_token)
        return {
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            "score": score,
        }

    def preprocess(split):
        # Tokenize, then keep only samples longer than 6 tokens.
        split = split.map(
            format_func,
            batched=False,
            remove_columns=split.column_names,
        )
        return split.filter(
            function=lambda x: len(x["input_ids"]) > 6
        )

    train_dataset = preprocess(dataset_dict["train"])
    valid_dataset = preprocess(dataset_dict["validation"])

    data_collator = DataCollatorWithPadding(tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_model_dir,
        # overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=500,
        learning_rate=5e-5,
        warmup_steps=1000,
        weight_decay=0.01,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=args.num_workers or 0,
        remove_unused_columns=False,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    trainer = SST2RewardTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()
examples/tutorials/rlhf/gpt2_sst2/step_4_test_reward_model.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import platform
7
+ from typing import Any, Dict, List, Optional, Union, Tuple
8
+
9
+ if platform.system() in ("Windows", "Darwin"):
10
+ from project_settings import project_path, temp_directory
11
+ else:
12
+ project_path = os.path.abspath("../../../")
13
+ project_path = Path(project_path)
14
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
15
+
16
+ from datasets import load_dataset
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from transformers import (AutoTokenizer,
21
+ GPT2PreTrainedModel, GPT2Config, GPT2Model,
22
+ )
23
+
24
+
25
def get_args():
    """Parse the command-line options of the reward-model test script.

    Returns:
        argparse.Namespace: the parsed options. ``device`` is always a plain
        string ("cuda" or "cpu" unless overridden) and ``dataset_streaming``
        is ``None`` when the flag is omitted, otherwise a bool.
    """

    def parse_bool(text: str) -> bool:
        # Only the usual "off" spellings are false, so every value that used
        # to switch streaming on still does; "false"/"0" no longer enable it.
        return text.strip().lower() not in ("", "0", "false", "no", "off", "none")

    on_desktop = platform.system() in ("Windows", "Darwin")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        # Alternatives used during development:
        #   "openai-community/gpt2"
        #   (project_path / "trained_models/gpt2-sst2-reward").as_posix()
        default=(project_path / "trained_models/gpt2-sst2-reward-20260213-2122").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--dataset_path",
        default="stanfordnlp/sst2",
        type=str,
    )
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )
    parser.add_argument("--dataset_streaming", default=None, type=parse_bool)

    parser.add_argument(
        "--num_workers",
        # os.cpu_count() may return None; fall back to 2 so the division is safe.
        default=None if on_desktop else (os.cpu_count() or 2) // 2,
        type=int,
    )
    parser.add_argument(
        "--device",
        # Keep the default a string so it has the same type as a CLI value.
        default="cuda" if torch.cuda.is_available() else "cpu",
        type=str,
    )
    args = parser.parse_args()
    return args
65
+
66
+
67
class RewardHead(nn.Module):
    """Linear layer mapping the last (hidden) dimension to one reward logit."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=1)
        self._post_init()

    def _post_init(self):
        """Re-initialise the layer: normal weights, zero bias."""
        # Standard deviation shrinks with the input width: 1 / sqrt(hidden_size + 1).
        init_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.linear.weight, std=init_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # The last dimension goes from hidden_size to 1; the leading
        # dimensions (batch and sequence in this script) are kept.
        return self.linear(hidden_states)
86
+
87
+
88
class GPT2RewardModel(GPT2PreTrainedModel):
    """GPT-2 backbone followed by a per-token reward head and a sigmoid."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.reward_head = RewardHead(config.hidden_size)
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, torch.Tensor]:
        """Return one reward in (0, 1) per input position."""
        backbone_out = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Hidden states of the final layer feed the reward head.
        final_layer = backbone_out.hidden_states[-1]
        # The head emits a trailing dimension of size 1; drop it before squashing.
        logits = self.reward_head(final_layer).squeeze(-1)
        return torch.sigmoid(logits)
114
+
115
+
116
def main():
    """Print the reward model's score for every SST-2 validation sentence.

    For each example the sentence (with the EOS token appended) is scored and
    the reward at the last position is printed next to the dataset label.
    """
    args = get_args()

    model = GPT2RewardModel.from_pretrained(args.model_name)
    model = model.to(args.device)
    # Inference only: make sure dropout is off so scores are reproducible.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    dataset_dict = load_dataset(
        path=args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        cache_dir=args.dataset_cache_dir,
        # num_proc=args.num_workers if not args.dataset_streaming else None,
        streaming=args.dataset_streaming,
    )
    # Other available splits: "train", "test".
    dataset = dataset_dict["validation"]

    for example in dataset:
        sentence: str = example["sentence"]
        score: float = float(example["label"])

        sentence += tokenizer.eos_token
        tokenized = tokenizer(sentence, return_tensors="pt")
        # The tokenizer creates CPU tensors; move them next to the model,
        # otherwise the forward pass fails when the model sits on CUDA.
        model_inputs = {key: value.to(args.device) for key, value in tokenized.items()}
        with torch.no_grad():
            rewards = model(**model_inputs)
        # Single-sentence batch: take row 0, then the reward at the final position.
        rewards = rewards[0].detach().cpu().numpy()
        last_token_reward = rewards[-1]
        msg = f"last_token_reward: {last_token_reward}\nscore: {score}\nsentence: {sentence}\n"
        print(msg)

    return


if __name__ == "__main__":
    main()
examples/tutorials/rlhf/gpt2_sst2/step_5_ppo_rlhf.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import copy
5
+ import os
6
+ import random
7
+ from pathlib import Path
8
+ import platform
9
+ from typing import Optional, Tuple, List, Dict, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import DataLoader
16
+ from datasets import load_dataset
17
+ from transformers import (
18
+ AutoTokenizer, AutoModelForCausalLM, GPT2PreTrainedModel,
19
+ GPT2Config, GPT2Model, GPT2LMHeadModel, DataCollatorWithPadding
20
+ )
21
+
22
+ # 路径配置
23
+ if platform.system() in ("Windows", "Darwin"):
24
+ from project_settings import project_path, temp_directory
25
+ else:
26
+ project_path = Path(os.path.abspath("../../../"))
27
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
28
+
29
+
30
def get_args():
    """Build the argument parser for the PPO script and parse the command line."""
    models_root = project_path / "trained_models"
    parser = argparse.ArgumentParser()

    # Model and data locations.
    parser.add_argument(
        "--reward_model_name",
        type=str,
        default=(models_root / "gpt2-sst2-reward").as_posix(),
    )
    parser.add_argument(
        "--sft_model_name",
        type=str,
        default=(models_root / "gpt2-sst2-generation").as_posix(),
    )
    parser.add_argument("--dataset_path", default="stanfordnlp/sst2", type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )

    # Remaining options share the same shape: (flag, default, type).
    simple_options = (
        ("--valid_dataset_size", 1000, int),
        # training parameters (a small batch is used because this runs on CPU)
        ("--batch_size", 16, int),
        ("--ppo_epochs", 4, int),
        ("--mini_batch_size", 4, int),
        ("--kl_beta", 0.2, float),
        ("--gamma", 1.0, float),
        ("--lam", 0.95, float),
        ("--clip_epsilon", 0.2, float),
        ("--lr", 1e-5, float),
        ("--max_epochs", 10, int),
        # generation parameters
        ("--max_new_tokens", 32, int),
        ("--top_p", 0.85, float),
        ("--temperature", 0.85, float),
        ("--min_response_len", 5, int),
        ("--max_response_len", 16, int),
        # misc
        ("--num_workers", 0 if platform.system() == "Windows" else 2, int),
        ("--device", "cpu", str),  # CPU is forced by default
    )
    for flag, default, kind in simple_options:
        parser.add_argument(flag, default=default, type=kind)

    return parser.parse_args()
66
+
67
+
68
class ValueHead(nn.Module):
    """Value head: predicts one scalar value for every token position."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self._init_weights()

    def _init_weights(self):
        """Normal weights with std 1/sqrt(in_features + 1), zero bias."""
        fan_in = self.linear.in_features
        init_std = 1.0 / np.sqrt(fan_in + 1)
        nn.init.normal_(self.linear.weight, std=init_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # Project to a single unit, then drop that trailing unit dimension.
        projected = self.linear(hidden_states)
        return projected.squeeze(-1)
82
+
83
+
84
class GPT2ActorCritic(GPT2PreTrainedModel):
    """Actor-critic model: returns language-model logits together with values."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.lm = GPT2LMHeadModel(config)
        self.value_head = ValueHead(config.hidden_size)
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        """Return ``(logits, values)`` for the given token ids."""
        lm_out = self.lm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Values are computed from the hidden states of the last layer.
        final_hidden = lm_out.hidden_states[-1]
        token_values = self.value_head(final_hidden)
        return lm_out.logits, token_values

    def generate(self, *args, **kwargs):
        """Delegate text generation to the wrapped language model."""
        return self.lm.generate(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name):
        """Build the model and copy weights from a pretrained GPT2LMHeadModel.

        Only ``self.lm`` receives pretrained weights; the value head keeps its
        fresh initialisation.
        """
        # NOTE(review): this loader reads a plain GPT2LMHeadModel checkpoint;
        # whether it can re-load a checkpoint written by this class's own
        # save_pretrained has not been verified.
        actor_critic = cls(GPT2Config.from_pretrained(pretrained_model_name))
        sft_weights = GPT2LMHeadModel.from_pretrained(pretrained_model_name).state_dict()
        actor_critic.lm.load_state_dict(sft_weights, strict=False)
        return actor_critic
114
+
115
+
116
class GPT2RewardModel(GPT2PreTrainedModel):
    """Reward model: predicts a reward in (0, 1) for every token position.

    The head is wrapped in a small module that owns a ``linear`` attribute so
    its parameters are named ``reward_head.linear.weight`` / ``.bias``. That
    is the layout used by the ``RewardHead``-based reward model in the sibling
    scripts of this tutorial; with a bare ``nn.Linear`` the parameter names
    would be ``reward_head.weight`` / ``.bias`` and ``from_pretrained`` would
    leave the head randomly initialised.
    NOTE(review): assumes the checkpoint at ``--reward_model_name`` was saved
    from that layout — confirm against the training script.
    """

    class _RewardHead(nn.Module):
        """Single linear unit kept under the attribute name ``linear``."""

        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear = nn.Linear(hidden_size, 1)

        def forward(self, hidden_states):
            return self.linear(hidden_states)

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.reward_head = self._RewardHead(config.hidden_size)
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        """Return per-token rewards of shape [batch, seq_len]."""
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Score the final layer's hidden states and drop the unit dimension.
        rewards = self.reward_head(outputs.hidden_states[-1]).squeeze(-1)
        return torch.sigmoid(rewards)  # [batch, seq_len]
133
+
134
+
135
class PPOAgent:
    """PPO training agent that bundles rollout, advantage and update logic.

    Layout conventions used between the methods:

    * ``collect_rollouts`` returns per-sample lists (query ids, response ids,
      scalar sequence reward).
    * ``compute_advantages_and_returns`` works on right-padded
      ``[batch, seq_len]`` tensors plus a validity mask.
    * ``ppo_step`` receives the response-token quantities flattened to 1-D in
      sample order (all tokens of sample 0, then sample 1, ...).
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # Tokenizer: GPT-2 has no pad token, so EOS doubles as padding.
        self.tokenizer = AutoTokenizer.from_pretrained(args.sft_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Models.
        print("Loading models...")
        self.actor_critic = GPT2ActorCritic.from_pretrained(args.sft_model_name).to(self.device)
        # Dropout is switched off so that log-probabilities computed during
        # rollout, by the reference model and inside the PPO update are
        # comparable; gradients still flow in eval mode.
        self.actor_critic.eval()
        self.reward_model = GPT2RewardModel.from_pretrained(args.reward_model_name).to(self.device)
        self.reward_model.eval()

        # Frozen reference model used for the KL penalty.
        self.ref_model = copy.deepcopy(self.actor_critic).to(self.device)
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad_(False)

        # Optimizer.
        self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr=args.lr)

        # Training state.
        self.training_step = 0

    def prepare_dataset(self):
        """Load the training split and attach a short ``query_ids`` prompt.

        Returns:
            A dataset of at most 5000 examples whose sentences tokenize to
            more than 8 tokens, each with a ``query_ids`` column holding the
            first 2-6 token ids of the sentence.
        """
        print("Loading dataset...")
        dataset = load_dataset(
            path=self.args.dataset_path,
            cache_dir=self.args.dataset_cache_dir,
            split="train"
        )

        def long_enough(example):
            # Keep only sentences that are long enough.
            return len(self.tokenizer(example["sentence"])["input_ids"]) > 8

        def add_query(example):
            # Randomly keep the first 2-6 tokens as the query.
            tokens = self.tokenizer(example["sentence"])["input_ids"]
            return {"query_ids": tokens[:random.randint(2, 6)]}

        dataset = dataset.filter(long_enough)
        dataset = dataset.select(range(min(len(dataset), 5000)))  # small subset for CPU
        # filter() only uses the predicate's boolean result, so the new column
        # has to be added with map().
        dataset = dataset.map(add_query)

        return dataset

    def collect_rollouts(self, batch):
        """Sample one response per query and score it with the reward model.

        Args:
            batch: mapping with a ``"query_ids"`` list of token-id lists.

        Returns:
            ``(query_ids_list, response_ids_list, rewards_list)`` where the
            first two hold 1-D id tensors and the last holds 0-dim reward
            tensors rescaled from (0, 1) to (-1, 1).
        """
        query_ids_list = []
        response_ids_list = []
        rewards_list = []

        for raw_query in batch["query_ids"]:
            query_ids = torch.tensor(raw_query, dtype=torch.long, device=self.device)
            response_len = random.randint(
                self.args.min_response_len,
                self.args.max_response_len
            )

            # Neither generation nor reward scoring needs gradients; keeping
            # the reward graph-free also keeps it out of the PPO backward pass.
            with torch.no_grad():
                full_ids = self.actor_critic.generate(
                    input_ids=query_ids.unsqueeze(0),
                    attention_mask=torch.ones_like(query_ids).unsqueeze(0),
                    max_new_tokens=response_len,
                    do_sample=True,
                    top_p=self.args.top_p,
                    temperature=self.args.temperature,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )[0]

                # Sequence reward = reward predicted at the last token.
                reward = self.reward_model(
                    full_ids.unsqueeze(0),
                    attention_mask=torch.ones_like(full_ids).unsqueeze(0)
                )[0, -1]

            query_ids_list.append(query_ids)
            response_ids_list.append(full_ids[len(query_ids):])
            # Rescale to [-1, 1].
            rewards_list.append(2 * (reward - 0.5))

        return query_ids_list, response_ids_list, rewards_list

    def compute_advantages_and_returns(self, log_probs, values, rewards, masks):
        """Compute GAE advantages and returns for right-padded sequences.

        Args:
            log_probs: unused; kept for interface compatibility.
            values: ``[batch, seq_len]`` value predictions.
            rewards: ``[batch, seq_len]`` per-token rewards.
            masks: ``[batch, seq_len]``, 1 for real tokens and 0 for padding.

        Returns:
            ``(advantages, returns)``; advantages are whitened over the
            masked positions and zero elsewhere, returns are un-whitened.
        """
        masks = masks.to(values.dtype)
        # Zeroing padded positions makes the backward recursion a no-op there,
        # so the bootstrap value after each sequence's last real token is 0.
        values = values * masks
        rewards = rewards * masks

        seq_len = rewards.shape[1]
        advantages = torch.zeros_like(rewards)
        gae = rewards.new_zeros(rewards.shape[0])

        for t in reversed(range(seq_len)):
            if t == seq_len - 1:
                next_value = torch.zeros_like(gae)
            else:
                next_value = values[:, t + 1]

            delta = rewards[:, t] + self.args.gamma * next_value - values[:, t]
            gae = delta + self.args.gamma * self.args.lam * gae
            advantages[:, t] = gae

        # Returns use the raw advantages, before normalisation.
        returns = advantages + values

        # Whiten over valid positions only.
        advantages = self.masked_whiten(advantages, masks)
        return advantages, returns

    def masked_whiten(self, values, mask):
        """Normalise ``values`` to zero mean / unit variance over ``mask``.

        Positions where ``mask`` is 0 are returned as 0.
        """
        mask = mask.float()
        # Guard against an all-zero mask.
        count = mask.sum().clamp(min=1.0)
        mean = (values * mask).sum() / count
        var = (((values - mean) * mask) ** 2).sum() / count
        whitened = (values - mean) * torch.rsqrt(var + 1e-8)
        return whitened * mask

    def _response_log_probs_and_values(self, model, query_ids, response_ids):
        """Return ``model``'s log-probs and values for the response tokens.

        Both results are 1-D with ``len(response_ids)`` entries; entry ``k``
        belongs to the position that predicts response token ``k``.
        """
        full_ids = torch.cat([query_ids, response_ids]).unsqueeze(0).to(self.device)
        logits, values = model(full_ids, torch.ones_like(full_ids))

        # Position t predicts token t + 1, hence the shift by one.
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        log_probs = torch.gather(
            log_probs, 2,
            full_ids[:, 1:].unsqueeze(-1)
        ).squeeze(-1)

        start = len(query_ids) - 1
        end = start + len(response_ids)
        return log_probs[0, start:end], values[0, start:end]

    def ppo_step(self, batch_data):
        """Compute the PPO loss for one pass over a rollout batch.

        Args:
            batch_data: ``(query_ids_list, response_ids_list, old_log_probs,
                advantages, returns, masks)`` with the tensors flattened over
                response tokens in sample order.

        Returns:
            ``(loss, policy_loss, value_loss)``.
        """
        (query_ids_list, response_ids_list, old_log_probs,
         advantages, returns, masks) = batch_data

        # Concatenate query + response for every sample.
        full_ids_list = [
            torch.cat([q, r]) for q, r in zip(query_ids_list, response_ids_list)
        ]

        # Padding.
        padded = self.tokenizer.pad(
            {"input_ids": full_ids_list},
            padding=True,
            return_tensors="pt"
        )
        input_ids = padded["input_ids"].to(self.device)
        attention_mask = padded["attention_mask"].to(self.device)

        # Forward pass.
        logits, values = self.actor_critic(input_ids, attention_mask)

        # Log-probabilities of the tokens that were actually produced.
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        log_probs = torch.gather(
            log_probs, 2,
            input_ids[:, 1:].unsqueeze(-1)
        ).squeeze(-1)

        # Keep only the response part, flattened in sample order.
        new_log_probs = []
        value_pred = []
        for i, (q, r) in enumerate(zip(query_ids_list, response_ids_list)):
            start = len(q) - 1
            end = start + len(r)
            new_log_probs.append(log_probs[i, start:end])
            value_pred.append(values[i, start:end])
        new_log_probs = torch.cat(new_log_probs)
        value_pred = torch.cat(value_pred)

        # Probability ratio against the rollout policy.
        old_log_probs = old_log_probs.detach()
        advantages = advantages.detach()
        returns = returns.detach()
        ratio = torch.exp(new_log_probs - old_log_probs)

        # Clipped policy loss.
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.args.clip_epsilon,
                            1 + self.args.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # Value loss.
        value_loss = F.mse_loss(value_pred, returns)

        # Total loss.
        loss = policy_loss + 0.5 * value_loss

        return loss, policy_loss, value_loss

    def train_epoch(self, dataset):
        """Run one pass over ``dataset``; return mean policy and value loss."""
        total_policy_loss = 0.0
        total_value_loss = 0.0
        num_batches = 0
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        for batch_idx in range(0, len(dataset), self.args.batch_size):
            # 1. Collect rollouts.
            batch = dataset[batch_idx:batch_idx + self.args.batch_size]
            query_ids_list, response_ids_list, rewards_list = self.collect_rollouts(batch)

            # 2. Old log-probs, values and per-token rewards, sample by sample.
            old_log_probs_list = []
            values_list = []
            token_rewards_list = []

            with torch.no_grad():
                for q_ids, r_ids, score in zip(query_ids_list, response_ids_list, rewards_list):
                    old_lp, old_values = self._response_log_probs_and_values(
                        self.actor_critic, q_ids, r_ids)
                    ref_lp, _ = self._response_log_probs_and_values(
                        self.ref_model, q_ids, r_ids)

                    # Per-token KL penalty against the frozen reference model;
                    # the sequence reward is added on the last response token.
                    token_rewards = -self.args.kl_beta * (old_lp - ref_lp)
                    token_rewards[-1] = token_rewards[-1] + score

                    old_log_probs_list.append(old_lp)
                    values_list.append(old_values)
                    token_rewards_list.append(token_rewards)

            # Right-pad to [batch, max_response_len]; the mask marks real tokens
            # so advantages never leak across sample boundaries.
            values = pad_sequence(values_list, batch_first=True).to(self.device)
            rewards = pad_sequence(token_rewards_list, batch_first=True).to(self.device)
            padded_log_probs = pad_sequence(old_log_probs_list, batch_first=True).to(self.device)
            masks = pad_sequence(
                [torch.ones_like(v) for v in values_list], batch_first=True
            ).to(self.device)

            # 3. Advantages and returns.
            advantages, returns = self.compute_advantages_and_returns(
                padded_log_probs, values, rewards, masks
            )

            # 4. Several PPO updates on the same rollouts. Boolean indexing
            #    flattens row by row, i.e. in sample order.
            valid = masks.bool()
            batch_data = (query_ids_list, response_ids_list,
                          torch.cat(old_log_probs_list).to(self.device),
                          advantages[valid], returns[valid], masks[valid])

            for _ in range(self.args.ppo_epochs):
                loss, policy_loss, value_loss = self.ppo_step(batch_data)

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 1.0)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                num_batches += 1
                self.training_step += 1

            if batch_idx % 100 == 0:
                print(f"Batch {batch_idx}/{len(dataset)}: "
                      f"policy_loss={total_policy_loss / num_batches:.4f}, "
                      f"value_loss={total_value_loss / num_batches:.4f}")

        # Avoid dividing by zero when the dataset is empty.
        denom = max(num_batches, 1)
        return total_policy_loss / denom, total_value_loss / denom

    def train(self):
        """Main training loop: prepare the data, then run ``max_epochs`` epochs."""
        dataset = self.prepare_dataset()
        print(f"Dataset size: {len(dataset)}")

        for epoch in range(self.args.max_epochs):
            print(f"\n=== Epoch {epoch + 1}/{self.args.max_epochs} ===")
            policy_loss, value_loss = self.train_epoch(dataset)
            print(f"Epoch {epoch + 1} finished: "
                  f"policy_loss={policy_loss:.4f}, value_loss={value_loss:.4f}")
410
+
411
+
412
def main():
    """Run PPO training end to end, then save the policy and its tokenizer."""
    cli_args = get_args()
    print("PPO Training with CPU")
    print(f"Arguments: {cli_args}")

    # Build the agent and train.
    ppo_agent = PPOAgent(cli_args)
    ppo_agent.train()

    # Persist model and tokenizer under <sft_model_name>/ppo_trained.
    save_dir = Path(cli_args.sft_model_name) / "ppo_trained"
    save_dir.mkdir(exist_ok=True, parents=True)
    for artifact in (ppo_agent.actor_critic, ppo_agent.tokenizer):
        artifact.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")


if __name__ == "__main__":
    main()
examples/tutorials/rlhf/gpt2_sst2/step_5_ppo_rlhf2.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import copy
5
+ import os
6
+ from pathlib import Path
7
+ import platform
8
+ import random
9
+ from typing import Any, Dict, List, Optional, Union, Tuple
10
+
11
+ if platform.system() in ("Windows", "Darwin"):
12
+ from project_settings import project_path, temp_directory
13
+ else:
14
+ project_path = os.path.abspath("../../../")
15
+ project_path = Path(project_path)
16
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
17
+
18
+ from datasets import load_dataset
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.utils.data import DataLoader
23
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
24
+ from transformers import (AutoTokenizer,
25
+ GPT2PreTrainedModel, GPT2Config, GPT2Model, GPT2LMHeadModel,
26
+ )
27
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
28
+ from transformers import DataCollatorWithPadding
29
+
30
+
31
def get_args():
    """Parse the command-line options of the PPO RLHF script.

    Returns:
        argparse.Namespace: the parsed options. ``device`` is always a plain
        string ("cuda" or "cpu" unless overridden) and ``dataset_streaming``
        is ``None`` when the flag is omitted, otherwise a bool.
    """

    def parse_bool(text: str) -> bool:
        # Only the usual "off" spellings are false, so every value that used
        # to switch streaming on still does; "false"/"0" no longer enable it.
        return text.strip().lower() not in ("", "0", "false", "no", "off", "none")

    on_desktop = platform.system() in ("Windows", "Darwin")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--reward_model_name",
        default=(project_path / "trained_models/gpt2-sst2-reward").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--sft_model_name",
        default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--dataset_path",
        default="stanfordnlp/sst2",
        type=str,
    )
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )
    parser.add_argument("--dataset_streaming", default=None, type=parse_bool)
    parser.add_argument("--valid_dataset_size", default=1000, type=int)
    parser.add_argument("--shuffle_buffer_size", default=5000, type=int)

    parser.add_argument(
        "--output_model_dir",
        default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
        type=str,
    )
    # train
    parser.add_argument("--batch_size", default=32, type=int)

    # generator
    parser.add_argument(
        "--max_new_tokens",
        default=128,  # 8192, 128
        type=int,
    )
    parser.add_argument("--top_p", default=0.85, type=float)
    parser.add_argument("--temperature", default=0.85, type=float)

    # other
    parser.add_argument(
        "--num_workers",
        # os.cpu_count() may return None; fall back to 2 so the division is safe.
        default=None if on_desktop else (os.cpu_count() or 2) // 2,
        type=int,
    )
    parser.add_argument(
        "--device",
        # A device is a name such as "cuda" or "cpu": parse it as a string
        # (type=int rejected every value given on the command line).
        default="cuda" if torch.cuda.is_available() else "cpu",
        type=str,
    )
    args = parser.parse_args()
    return args
94
+
95
+
96
class RewardHead(nn.Module):
    """Single linear unit that turns hidden states into reward logits."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=1)
        self._post_init()

    def _post_init(self):
        """Normal-initialise the weights and zero the bias."""
        # std = 1 / sqrt(hidden_size + 1)
        weight_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.linear.weight, std=weight_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # Maps the trailing hidden_size dimension to a trailing dimension of 1.
        return self.linear(hidden_states)
115
+
116
+
117
class GPT2RewardModel(GPT2PreTrainedModel):
    """GPT-2 backbone with a reward head; emits one reward in (0, 1) per token."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.reward_head = RewardHead(config.hidden_size)
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, torch.Tensor]:
        """Score every position of ``input_ids``."""
        backbone_out = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Final-layer hidden states -> one logit per position (unit dim dropped).
        per_token_logits = self.reward_head(backbone_out.hidden_states[-1]).squeeze(-1)
        # Squash the logits into (0, 1).
        return torch.sigmoid(per_token_logits)
143
+
144
+
145
class ValueHead(nn.Module):
    """Single linear unit that turns hidden states into value logits."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=1)
        self._post_init()

    def _post_init(self):
        """Normal-initialise the weights and zero the bias."""
        # std = 1 / sqrt(hidden_size + 1)
        weight_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.linear.weight, std=weight_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # The trailing unit dimension is kept; callers squeeze it themselves.
        value_logits = self.linear(hidden_states)
        return value_logits
164
+
165
+
166
class GPT2ActorCriticModel(GPT2PreTrainedModel):
    """GPT-2 language model (actor) sharing its backbone with a value head (critic)."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.lm = GPT2LMHeadModel(config)
        self.value_head = ValueHead(config.hidden_size)
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, torch.Tensor]:
        """Return ``(lm_logits, values)``; ``values`` has one scalar per position."""
        transformer_outputs = self.lm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        lm_logits = transformer_outputs.logits

        # Values come from the last layer's hidden states.
        last_hidden_state = transformer_outputs.hidden_states[-1]
        values = torch.squeeze(self.value_head(last_hidden_state), -1)
        # The values are deliberately left unbounded: this script rescales
        # rewards to (-1, 1), so a sigmoid-squashed critic confined to (0, 1)
        # could never fit a negative return.
        return lm_logits, values

    @classmethod
    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
        """Create the model and load language-model weights from a GPT-2 checkpoint.

        Extra positional and keyword arguments are forwarded to
        ``GPT2LMHeadModel.from_pretrained``. Only ``self.lm`` receives
        pretrained weights; the value head keeps its fresh initialisation.
        """
        pretrained_model = GPT2LMHeadModel.from_pretrained(
            model_name_or_path, *model_args, **kwargs
        )
        # Reuse the loaded model's config so both modules agree on the architecture.
        model = cls(pretrained_model.config)
        model.lm.load_state_dict(pretrained_model.state_dict(), strict=False)
        return model

    def generate(self, *args, **kwargs):
        """Delegate generation to the wrapped language model."""
        return self.lm.generate(*args, **kwargs)
206
+
207
+
208
def masked_mean(values, mask):
    """Mean of ``values`` weighted by ``mask``."""
    weighted_total = torch.sum(values * mask)
    return weighted_total / torch.sum(mask)


def masked_var(values, mask):
    """Variance of ``values`` weighted by ``mask`` (no Bessel correction)."""
    deviations = values - masked_mean(values, mask)
    return masked_mean(deviations.pow(2), mask)


def masked_whiten(values, mask):
    """
    Whiten ``values`` using statistics taken over ``mask``.

    The masked entries end up with unit variance while their mean is kept
    unchanged (the mean is added back after scaling).
    """
    mean = masked_mean(values, mask)
    inv_std = torch.rsqrt(masked_var(values, mask) + 1e-8)
    return (values - mean) * inv_std + mean
228
+
229
+
230
+ def main():
231
+ args = get_args()
232
+
233
+ device = torch.device(args.device)
234
+
235
+ # reward_model
236
+ reward_model = GPT2RewardModel.from_pretrained(args.reward_model_name)
237
+ reward_model = reward_model.to(args.device)
238
+ reward_model.eval()
239
+ reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name)
240
+ reward_tokenizer.pad_token = reward_tokenizer.eos_token
241
+
242
+ # actor_critic_model
243
+ actor_critic_model = GPT2ActorCriticModel.from_pretrained(args.sft_model_name)
244
+ actor_critic_model = actor_critic_model.to(args.device)
245
+ actor_critic_tokenizer = AutoTokenizer.from_pretrained(args.sft_model_name)
246
+ actor_critic_tokenizer.pad_token = actor_critic_tokenizer.eos_token
247
+ actor_critic_tokenizer.pad_token_id = actor_critic_tokenizer.eos_token_id
248
+
249
+ # ref_model
250
+ ref_model = copy.deepcopy(actor_critic_model)
251
+ ref_model = ref_model.to(args.device)
252
+ ref_model.eval()
253
+
254
+ dataset_dict = load_dataset(
255
+ path=args.dataset_path,
256
+ name=args.dataset_name,
257
+ split=args.dataset_split,
258
+ cache_dir=args.dataset_cache_dir,
259
+ # num_proc=args.num_workers if not args.dataset_streaming else None,
260
+ streaming=args.dataset_streaming,
261
+ )
262
+ train_dataset = dataset_dict["train"]
263
+ # valid_dataset = dataset_dict["validation"]
264
+ # test_dataset = dataset_dict["test"]
265
+
266
+ def format_func(example):
267
+ sentence: str = example["sentence"]
268
+ score: float = float(example["label"])
269
+ tokenized = actor_critic_tokenizer(sentence)
270
+ input_ids = tokenized["input_ids"]
271
+ attention_mask = tokenized["attention_mask"]
272
+ result = {
273
+ "input_ids": input_ids,
274
+ "attention_mask": attention_mask,
275
+ }
276
+ return result
277
+
278
+ train_dataset = train_dataset.map(
279
+ format_func,
280
+ batched=False,
281
+ remove_columns=train_dataset.column_names,
282
+ )
283
+ train_dataset = train_dataset.filter(
284
+ function=lambda x: len(x["input_ids"]) > 8
285
+ )
286
+ def token_truncate(example):
287
+ target_length = random.randint(2, 6)
288
+ input_ids = example["input_ids"]
289
+ attention_mask = example["attention_mask"]
290
+ input_ids = input_ids[:target_length]
291
+ attention_mask = attention_mask[:target_length]
292
+ text = actor_critic_tokenizer.decode(input_ids)
293
+ result = {
294
+ "input_ids": input_ids,
295
+ "attention_mask": attention_mask,
296
+ # "text": text,
297
+ }
298
+ return result
299
+
300
+ train_dataset = train_dataset.map(
301
+ token_truncate,
302
+ batched=False,
303
+ remove_columns=train_dataset.column_names,
304
+ )
305
+ data_collator = DataCollatorWithPadding(
306
+ tokenizer=actor_critic_tokenizer,
307
+ padding=True,
308
+ )
309
+ train_data_loader = DataLoader(
310
+ dataset=train_dataset,
311
+ batch_size=args.batch_size,
312
+ shuffle=True,
313
+ num_workers=args.num_workers or 0,
314
+ collate_fn=data_collator,
315
+ )
316
+
317
+ # for train_batch in train_data_loader:
318
+ # print(train_batch)
319
+
320
+ for epoch_id in range(10):
321
+ for batch in train_data_loader:
322
+ input_ids = batch["input_ids"]
323
+ attention_mask = batch["attention_mask"]
324
+
325
+ query_ids_list = list()
326
+ query_and_response_ids_list = list()
327
+ response_ids_list = list()
328
+ reward_list = list()
329
+ for idx in range(args.batch_size):
330
+ input_ids_ = input_ids[idx]
331
+ attention_mask_ = attention_mask[idx]
332
+ input_ids_ = input_ids_.to(device)
333
+ attention_mask_ = attention_mask_.to(device)
334
+
335
+ with torch.no_grad():
336
+ query_and_response_ids = actor_critic_model.generate(
337
+ input_ids=input_ids_.unsqueeze(0),
338
+ attention_mask=attention_mask_.unsqueeze(0),
339
+ max_new_tokens=random.randint(5, 16),
340
+ do_sample=True,
341
+ top_p=0.85,
342
+ temperature=0.85,
343
+ pad_token_id=actor_critic_tokenizer.pad_token_id,
344
+ eos_token_id=actor_critic_tokenizer.eos_token_id,
345
+ repetition_penalty=1.0,
346
+ early_stopping=True,
347
+ ).squeeze(0)
348
+ query_ids_list.append(input_ids_)
349
+ query_and_response_ids_list.append(query_and_response_ids)
350
+ response_ids = query_and_response_ids[len(input_ids_):]
351
+ response_ids_list.append(response_ids)
352
+
353
+ reward = reward_model(
354
+ input_ids=query_and_response_ids.unsqueeze(0),
355
+ attention_mask=torch.ones_like(query_and_response_ids, dtype=torch.long).unsqueeze(0),
356
+ ).squeeze(0)[-1]
357
+ # 将奖励模型的评分从(0,1)缩放到(-1,1)
358
+ reward = 2 * (reward - 0.5)
359
+ reward_list.append(reward)
360
+
361
+ for query_ids, query_and_response_ids in zip(query_ids_list, query_and_response_ids_list):
362
+ print(actor_critic_tokenizer.decode(query_ids, skip_special_tokens=False))
363
+ print(actor_critic_tokenizer.decode(query_ids, skip_special_tokens=True))
364
+ print(actor_critic_tokenizer.decode(query_and_response_ids, skip_special_tokens=True))
365
+ exit(0)
366
+ #计算奖励
367
+ batch_ = list()
368
+ for query_and_response_ids in query_and_response_ids_list:
369
+ print(actor_critic_tokenizer.decode(query_and_response_ids))
370
+ batch_.append({
371
+ "input_ids": query_and_response_ids,
372
+ "attention_mask": torch.ones_like(query_and_response_ids),
373
+ })
374
+
375
+ batch_ = data_collator(batch_)
376
+ input_ids = batch_["input_ids"]
377
+ attention_mask = batch_["attention_mask"]
378
+ input_ids = input_ids.to(device)
379
+ attention_mask = attention_mask.to(device)
380
+
381
+ logits, values = actor_critic_model(input_ids=input_ids, attention_mask=attention_mask)
382
+ ref_logits, _ = ref_model(input_ids=input_ids, attention_mask=attention_mask)
383
+ log_prob = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
384
+ ref_log_prob = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)
385
+ index = input_ids[:, 1:].unsqueeze(-1)
386
+ log_prob = torch.gather(log_prob, dim=2, index=index).squeeze(-1)
387
+ ref_log_prob = torch.gather(ref_log_prob, dim=2, index=index).squeeze(-1)
388
+
389
+ kl = log_prob - ref_log_prob
390
+ beta = 0.2
391
+ kl_penalty = - beta * kl
392
+
393
+ rewards = kl_penalty
394
+ masks = torch.zeros_like(input_ids[:, 1:])
395
+ for idx in range(args.batch_size):
396
+ start = len(query_ids_list[idx]) - 1
397
+ end = start + len(response_ids_list[idx])
398
+ masks[idx, start:end] = 1
399
+ rewards[idx, end - 1] += reward_list[idx]
400
+ values[idx, :-1] *= masks[idx, :]
401
+ values[idx, -1] = 0
402
+ rewards = rewards * masks
403
+
404
+ # log_prob, rewards, kl_penalty, masks shape: [b, seq_len - 1]
405
+ # values shape: [b, seq_len]
406
+
407
+ # 计算优势
408
+ seq_len = rewards.shape[-1]
409
+ last_gae = 0.0
410
+ gamma, lam = 1.0, 0.95
411
+ advantage_reversed = list()
412
+ for t in reversed(range(seq_len)):
413
+ next_value = values[:, t + 1] if t < seq_len - 1 else 0.0
414
+ delta = rewards[:, t] + gamma * next_value - values[:, t]
415
+ last_gae = delta + gamma * lam * last_gae
416
+ advantage_reversed.append(last_gae)
417
+ advantages = torch.stack(advantage_reversed[::-1], dim=1)
418
+ advantages = masked_whiten(advantages, masks)
419
+ returns = advantages + values[:, :-1]
420
+
421
+ # advantages shape: [b, seq_len-1]
422
+ # returns shape: [b, seq_len-1]
423
+
424
+ exit(0)
425
+
426
+ return
427
+
428
+
429
+ if __name__ == "__main__":
430
+ main()
examples/tutorials/rlhf/gpt2_sst2/step_5_pre_ppo_rlhf.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import platform
7
+ from typing import Any, Dict, List, Optional, Union, Tuple
8
+
9
+ if platform.system() in ("Windows", "Darwin"):
10
+ from project_settings import project_path, temp_directory
11
+ else:
12
+ project_path = os.path.abspath("../../../")
13
+ project_path = Path(project_path)
14
+ temp_directory = Path("/root/autodl-tmp/OpenMiniMind/temp")
15
+
16
+ from datasets import load_dataset
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
21
+ from transformers import (AutoTokenizer,
22
+ GPT2PreTrainedModel, GPT2Config, GPT2Model, GPT2LMHeadModel,
23
+ )
24
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
25
+
26
+
27
def _parse_device(value: str) -> torch.device:
    """Convert a command-line device string (e.g. ``cpu``, ``cuda:0``) to a ``torch.device``.

    Raises:
        argparse.ArgumentTypeError: if torch does not accept the string, so that
            argparse prints a usage error instead of a raw traceback.
    """
    try:
        return torch.device(value)
    except RuntimeError as error:
        raise argparse.ArgumentTypeError(f"invalid device: {value!r}") from error


def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with the model locations, dataset options,
        generation settings, ``num_workers`` and ``device`` (a ``torch.device``).
    """
    parser = argparse.ArgumentParser()

    # models
    parser.add_argument(
        "--reward_model_name",
        default=(project_path / "trained_models/gpt2-sst2-reward").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--policy_model_name",
        default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
        type=str,
    )

    # dataset
    parser.add_argument(
        "--dataset_path",
        default="stanfordnlp/sst2",
        type=str,
    )
    parser.add_argument("--dataset_name", default=None, type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(temp_directory / "hub_datasets").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(temp_directory / "hub_models").as_posix(),
        type=str,
    )
    # NOTE(review): parsed as a plain string, so any non-empty value
    # (including "False") is truthy for whoever consumes it - confirm intent.
    parser.add_argument("--dataset_streaming", default=None, type=str)
    parser.add_argument("--valid_dataset_size", default=1000, type=int)
    parser.add_argument("--shuffle_buffer_size", default=5000, type=int)

    # NOTE(review): same default directory as --policy_model_name.
    parser.add_argument(
        "--output_model_dir",
        default=(project_path / "trained_models/gpt2-sst2-generation").as_posix(),
        type=str,
    )

    # generation
    parser.add_argument(
        "--max_new_tokens",
        default=128,  # 8192, 128
        type=int, help="最大生成长度(注意:并非模型实际长文本能力)"
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")

    # other
    # os.cpu_count() may return None; fall back to 1 so the division cannot fail.
    default_num_workers = (
        None if platform.system() in ("Windows", "Darwin") else (os.cpu_count() or 1) // 2
    )
    parser.add_argument(
        "--num_workers",
        default=default_num_workers,
        type=int,
    )
    # Device names are strings such as "cpu" or "cuda:0"; the previous
    # ``type=int`` made every explicit --device value a parse error.
    parser.add_argument(
        "--device",
        default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        type=_parse_device,
    )
    args = parser.parse_args()
    return args
87
+
88
+
89
class RewardHead(nn.Module):
    """Single linear layer mapping each hidden state to one scalar logit."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(in_features=hidden_size, out_features=1)
        self._post_init()

    def _post_init(self):
        """Draw the weights from a normal with std 1/sqrt(hidden_size + 1); zero the bias."""
        weight_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.linear.weight, std=weight_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len, 1]
        return self.linear(hidden_states)
108
+
109
+
110
class GPT2RewardModel(GPT2PreTrainedModel):
    """GPT-2 backbone followed by a RewardHead, producing one score in (0, 1) per token."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.reward_head = RewardHead(config.hidden_size)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, torch.Tensor]:
        """Return sigmoid-squashed per-token scores of shape [batch_size, seq_len]."""
        backbone_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        # Final entry of hidden_states: [batch_size, seq_len, hidden_size]
        final_hidden = backbone_outputs.hidden_states[-1]
        # Head output is [batch_size, seq_len, 1]; drop the trailing axis.
        per_token_logits = self.reward_head(final_hidden).squeeze(-1)
        return torch.sigmoid(per_token_logits)
136
+
137
+
138
class ValueHead(nn.Module):
    """Single linear layer mapping each hidden state to one scalar value logit."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(in_features=hidden_size, out_features=1)
        self._post_init()

    def _post_init(self):
        """Draw the weights from a normal with std 1/sqrt(hidden_size + 1); zero the bias."""
        weight_std = 1.0 / np.sqrt(self.hidden_size + 1)
        nn.init.normal_(self.linear.weight, std=weight_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, hidden_states):
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len, 1]
        value_logits = self.linear(hidden_states)
        return value_logits
157
+
158
+
159
class GPT2ActorCriticModel(GPT2PreTrainedModel):
    """GPT-2 language model (actor) sharing its backbone with a per-token value head (critic)."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.lm = GPT2LMHeadModel(config)
        self.value_head = ValueHead(config.hidden_size)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, torch.Tensor]:
        """Run the language model once and derive both outputs from it.

        Returns:
            A ``(lm_logits, values)`` tuple: the language-model logits and the
            sigmoid-squashed value estimates of shape [batch_size, seq_len].
        """
        # Call the module itself rather than ``.forward`` so that registered
        # forward hooks / pre-hooks are executed.
        lm_outputs = self.lm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        lm_logits = lm_outputs.logits

        # values
        last_hidden_state = lm_outputs.hidden_states[-1]
        # last_hidden_state shape: [batch_size, seq_len, hidden_size]
        values_logits = self.value_head(last_hidden_state)
        # values_logits shape: [batch_size, seq_len, 1]
        values = torch.squeeze(values_logits, -1)
        # values shape: [batch_size, seq_len]
        # NOTE(review): the sigmoid bounds every value estimate to (0, 1);
        # confirm this matches the range of the returns the critic must fit.
        values = torch.sigmoid(values)
        return lm_logits, values

    @classmethod
    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
        """Build the model from a pretrained ``GPT2LMHeadModel`` checkpoint.

        Only the language-model weights are loaded; the value head keeps its
        fresh initialisation. Extra arguments (e.g. ``cache_dir``) are passed
        on to ``GPT2LMHeadModel.from_pretrained`` instead of being dropped.

        Returns:
            The model in evaluation mode; call ``model.train()`` before training.
        """
        pretrained_model = GPT2LMHeadModel.from_pretrained(
            model_name_or_path, *model_args, **kwargs
        )
        # Reuse the config that was actually loaded instead of fetching it twice.
        model = cls(pretrained_model.config)
        model.lm.load_state_dict(pretrained_model.state_dict(), strict=False)
        # ``cls(config)`` leaves the module in training mode (dropout active);
        # switch to eval mode like the stock Hugging Face ``from_pretrained``.
        model.eval()
        return model

    def generate(self, *args, **kwargs):
        """Delegate text generation to the wrapped language model."""
        return self.lm.generate(*args, **kwargs)
199
+
200
+
201
def main():
    """Smoke-test the reward model and the actor-critic model before PPO training.

    Loads both models and their tokenizers, streams one sampled continuation of
    a short prompt, then prints the per-token value estimates for a sample
    sentence.
    """
    args = get_args()

    reward_model = GPT2RewardModel.from_pretrained(args.reward_model_name)
    reward_model = reward_model.to(args.device)
    reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name)
    reward_tokenizer.pad_token = reward_tokenizer.eos_token
    print(reward_model)
    print(reward_tokenizer)

    # Optional reward-model sanity check, kept for reference:
    # tokenized = reward_tokenizer(
    #     "this is very good movie, I recommend it.",
    #     return_tensors="pt"
    # )
    # rewards = reward_model(**tokenized)
    # rewards = rewards[0]
    # rewards = rewards.detach().cpu().numpy()
    # last_token_reward = rewards[-1]
    # # rewards: {rewards}\n
    # msg = f"last_token_reward: {last_token_reward}\n"
    # print(msg)
    # exit(0)

    # actor_critic_model
    actor_critic_model = GPT2ActorCriticModel.from_pretrained(args.policy_model_name)
    actor_critic_model = actor_critic_model.to(args.device)
    # Inference only: make sure dropout is disabled.
    actor_critic_model.eval()
    actor_critic_tokenizer = AutoTokenizer.from_pretrained(args.policy_model_name)
    actor_critic_tokenizer.pad_token = actor_critic_tokenizer.eos_token
    print(actor_critic_model)
    print(actor_critic_tokenizer)

    # Inputs must live on the same device as the model, otherwise generation
    # fails with a device mismatch when CUDA is used.
    tokenized = actor_critic_tokenizer(
        "this is ",
        return_tensors="pt"
    ).to(args.device)
    streamer = TextStreamer(actor_critic_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generated_ids = actor_critic_model.generate(
        inputs=tokenized["input_ids"], attention_mask=tokenized["attention_mask"],
        max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
        pad_token_id=actor_critic_tokenizer.pad_token_id, eos_token_id=actor_critic_tokenizer.eos_token_id,
        top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0,
    )
    # Drop the prompt tokens so only the newly generated text is decoded.
    prompt_length = len(tokenized["input_ids"][0])
    response = actor_critic_tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
    print(response)

    tokenized = actor_critic_tokenizer(
        "this is very good movie, I recommend it.",
        return_tensors="pt"
    ).to(args.device)
    # No gradients are needed for this inspection pass.
    with torch.no_grad():
        lm_logits, values = actor_critic_model(**tokenized)
    print(values)

    return
254
+
255
+
256
+ if __name__ == "__main__":
257
+ main()
tabs/chat_template_tab.py CHANGED
@@ -13,8 +13,10 @@ def run_chat_template(conversation: str, model_name: str, add_generation_prompt:
13
 
14
  result = tokenizer.apply_chat_template(
15
  conversation,
 
16
  tokenize=False,
17
  add_generation_prompt=add_generation_prompt,
 
18
  )
19
  return result
20
 
 
13
 
14
  result = tokenizer.apply_chat_template(
15
  conversation,
16
+ # tools=None,
17
  tokenize=False,
18
  add_generation_prompt=add_generation_prompt,
19
+ # enable_thinking=True,
20
  )
21
  return result
22