This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
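The Worker class imported below (labml_nn.rl.game) is not listed here. As a rough orientation only, the sketch below shows one way such a worker could be structured, assuming it wraps a gym-style Atari environment behind a multiprocessing.Pipe and answers the "reset", "step", and "close" messages the trainer sends; worker_process and the environment setup are assumptions, and the real worker also preprocesses frames into the stacked 4x84x84 observations used below.

import multiprocessing
import multiprocessing.connection

import gym


def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    # Hypothetical child-process loop: one environment per process, driven by messages from the trainer
    env = gym.make('BreakoutNoFrameskip-v4')
    env.seed(seed)  # older gym API; newer gym passes the seed to reset()
    while True:
        cmd, data = remote.recv()
        if cmd == "reset":
            remote.send(env.reset())
        elif cmd == "step":
            remote.send(env.step(data))  # (next_obs, reward, done, info)
        elif cmd == "close":
            remote.close()
            break


class Worker:
    # Parent-side handle; the trainer talks to the process through `child`
    def __init__(self, seed: int):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()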
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

Select the device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Scale observations from [0, 255] to [0, 1]
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
Number of workers
        self.n_workers = n_workers
Steps sampled on each update
        self.worker_steps = worker_steps
Number of training iterations
        self.train_epochs = epochs
Number of updates
        self.updates = updates
Size of mini batches used for training
        self.mini_batch_size = mini_batch_size
Update the target network every 250 updates
        self.update_target_model = update_target_model
Learning rate
        self.learning_rate = learning_rate
Exploration coefficient ε as a function of the number of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
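Piecewise maps the current update number to a value. Assuming it interpolates linearly between the given (update, value) points and returns outside_value beyond the last point, a minimal stand-alone equivalent would be:

def piecewise(x: float, points, outside_value: float) -> float:
    # Linear interpolation between consecutive (step, value) points
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if x0 <= x < x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    return outside_value

# With updates = 1,000,000 (as in the configs below), the exploration coefficient starts at 1.0,
# falls to 0.1 by update 25,000, then to 0.01 by update 500,000, and stays there afterwards:
# piecewise(12_500, [(0, 1.0), (25_000, 0.1), (500_000, 0.01)], outside_value=0.01)  -> 0.55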
β for prioritized replay as a function of the number of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
Replay buffer. The capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
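A power-of-two capacity is convenient because prioritized replay buffers are typically backed by binary segment trees over the priorities, which are simplest to index when the number of leaves is a power of two. The following is a simplified sketch of that idea, not the actual ReplayBuffer code:

class SumTree:
    # Minimal sum tree over `capacity` priorities. With a power-of-two capacity the leaves sit at
    # indexes [capacity, 2 * capacity) and node i has children 2 * i and 2 * i + 1.
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity)

    def set(self, idx: int, priority: float):
        # Update a leaf and propagate the new sums up to the root
        i = idx + self.capacity
        self.tree[i] = priority
        while i > 1:
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def find(self, prefix_sum: float) -> int:
        # Walk down from the root; larger priorities cover a larger share of [0, total)
        i = 1
        while i < self.capacity:
            if self.tree[2 * i] >= prefix_sum:
                i = 2 * i
            else:
                prefix_sum -= self.tree[2 * i]
                i = 2 * i + 1
        return i - self.capacity

Sampling a transition then amounts to drawing a uniform number in [0, total priority) and calling find, so transitions with larger priorities are picked proportionally more often.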
Model for sampling and training
        self.model = Model().to(device)
Target model used to compute the target Q values
        self.target_model = Model().to(device)
Create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
Initialize the tensor for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
Reset the workers
        for worker in self.workers:
            worker.child.send(("reset", None))
Get the initial observations
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
Loss function
        self.loss_func = QFuncLoss(0.99)
Optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

ε-greedy sampling of actions
    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
Sampling doesn't need gradients
        with torch.no_grad():
Sample the action with the highest Q value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
Uniformly sample a random action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
Pick the action based on is_choose_rand
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
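As a quick standalone illustration of this rule (not part of the trainer): with exploration_coefficient close to 1 the sampled actions are almost always random, and close to 0 they are almost always greedy.

import torch

q = torch.tensor([[0.1, 0.9, 0.2],
                  [0.5, 0.4, 0.3]])            # Q values for 2 workers and 3 actions
eps = 0.1                                       # exploration coefficient

greedy = torch.argmax(q, dim=-1)                # tensor([1, 0])
random_a = torch.randint(q.shape[-1], greedy.shape)
pick_random = torch.rand(greedy.shape) < eps    # True with probability eps per worker
actions = torch.where(pick_random, random_a, greedy)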
Sample data with the current policy
    def sample(self, exploration_coefficient: float):
This doesn't need gradients
        with torch.no_grad():
Sample worker_steps
            for t in range(self.worker_steps):
Get Q values for the current observations
                q_value = self.model(obs_to_torch(self.obs))
Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)
Run the sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))
Collect information from each worker
                for w, worker in enumerate(self.workers):
Get results after executing the action
                    next_obs, reward, done, info = worker.child.recv()
Add the transition to the replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
Update episode information. Collect episode info, which is available when an episode finishes; this includes the total reward and the length of the episode. Look at Game to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
Update the current observation
                    self.obs[w] = next_obs

Train the model
    def train(self, beta: float):
        for _ in range(self.train_epochs):
Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
Get the predicted Q values
            q_value = self.model(obs_to_torch(samples['obs']))
Get the Q values of the next state for double Q-learning. Gradients shouldn't propagate for these.
            with torch.no_grad():
Get the next-state Q values from the model being trained
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
Get the next-state Q values from the target model
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
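QFuncLoss combines these two estimates into the double Q-learning target. The snippet below is only a sketch of the standard double DQN target, not the exact QFuncLoss internals; γ is 0.99 to match QFuncLoss(0.99) above, and reward and done stand for tensors built from samples['reward'] and samples['done'].

# a* = argmax_a Q_online(s', a): pick the next action with the model being trained
best_next_action = double_q_value.argmax(dim=-1, keepdim=True)
# ...but evaluate that action with the target model
next_value = target_q_value.gather(-1, best_next_action).squeeze(-1)
# TD target: r + γ Q_target(s', a*), with no bootstrapping on terminal transitions
td_target = reward + 0.99 * next_value * (1.0 - done)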
Compute the temporal difference (TD) error and the loss.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
Calculate priorities for the replay buffer
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
Update the replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
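For reference, in proportional prioritized replay (Schaul et al., 2015) the priorities above turn into sampling probabilities and the importance-sampling weights passed to the loss roughly as in the sketch below; the exact computation lives inside ReplayBuffer, and the priority exponent α is presumably the 0.6 passed to its constructor.

import numpy as np

def importance_weights(priorities: np.ndarray, alpha: float, beta: float) -> np.ndarray:
    # P(i) ∝ p_i^alpha ; w_i = (N * P(i))^(-beta), normalized so the largest weight is 1
    probs = priorities ** alpha
    probs = probs / probs.sum()
    weights = (len(priorities) * probs) ** (-beta)
    return weights / weights.max()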
Set the learning rate
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
Zero out previously computed gradients
            self.optimizer.zero_grad()
Compute gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update the parameters based on the gradients
            self.optimizer.step()

Run the training loop
    def run_training_loop(self):
Information about the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)
Copy the model to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
ε, the exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
β for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)
Sample with the current policy
            self.sample(exploration)
Start training after the replay buffer is full
            if self.replay_buffer.is_full():
Train the model
                self.train(beta)
Periodically update the target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())
Save the tracked metrics
            tracker.save()
Periodically add a new line to the screen
            if (update + 1) % 1_000 == 0:
                logger.log()

Stop the workers
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():
Create the experiment
    experiment.create(name='dqn')
Configurations
    configs = {
Number of updates
        'updates': 1_000_000,
Number of epochs to train the model with sampled data
        'epochs': 8,
Number of worker processes
        'n_workers': 8,
Number of steps to run on each process for a single update
        'worker_steps': 4,
Mini batch size
        'mini_batch_size': 32,
Target model update interval
        'update_target_model': 250,
Learning rate
        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }
Set the configurations
    experiment.configs(configs)
Initialize the trainer
    m = Trainer(**configs)
Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
Stop the workers
    m.destroy()


if __name__ == "__main__":
    main()