This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.

import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

Select device.

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Scale observations from [0, 255] to [0, 1].

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

class Trainer:
    def __init__(self):

Number of workers.

        self.n_workers = 8

Steps sampled on each update.

        self.worker_steps = 4

Number of training iterations.

        self.train_epochs = 8

Number of updates.

        self.updates = 1_000_000

Size of mini batch for training.

        self.mini_batch_size = 32

Exploration as a function of updates.

        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

Update the target network every 250 updates.

        self.update_target_model = 250

$\beta$ for the replay buffer as a function of updates.

        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
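
For intuition, here is a minimal sketch of what such a schedule is assumed to compute: linear interpolation between the listed (update, value) points, falling back to outside_value beyond them. The helper and the example numbers below are illustrative only; the trainer uses the Piecewise class from labml_helpers.schedule.

def piecewise(points, outside_value, update):
    # Illustrative only: assumed linear interpolation between schedule points.
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if x0 <= update <= x1:
            return y0 + (y1 - y0) * (update - x0) / (x1 - x0)
    return outside_value

# With the exploration schedule above (updates = 1_000_000):
# piecewise([(0, 1.0), (25_000, 0.1), (500_000, 0.01)], 0.01, 0)       -> 1.0
# piecewise([(0, 1.0), (25_000, 0.1), (500_000, 0.01)], 0.01, 12_500)  -> 0.55
# piecewise([(0, 1.0), (25_000, 0.1), (500_000, 0.01)], 0.01, 750_000) -> 0.01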

Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2.

        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
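
The replay buffer implements prioritized experience replay; the actual data structures live in ReplayBuffer, but roughly, under the standard formulation, a transition $i$ with priority $p_i$ is sampled with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, and its loss is scaled by an importance-sampling weight $w_i = \big(N \cdot P(i)\big)^{-\beta}$, normalized so the largest weight is 1. A minimal NumPy sketch of that math, illustrative only:

import numpy as np

def prioritized_sample(priorities: np.ndarray, alpha: float, beta: float, batch_size: int):
    # Illustrative only; the exact normalization in ReplayBuffer may differ.
    probs = priorities ** alpha
    probs = probs / probs.sum()                      # P(i) = p_i^alpha / sum_k p_k^alpha
    indexes = np.random.choice(len(priorities), batch_size, p=probs)
    weights = (len(priorities) * probs[indexes]) ** (-beta)
    weights = weights / weights.max()                # scale so the largest weight is 1
    return indexes, weights

With n_workers = 8 and worker_steps = 4, each update adds 32 transitions, so the $2^{14} = 16{,}384$-slot buffer fills after 512 updates, which is when training starts (see run_training_loop below).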

Model for sampling and training.

        self.model = Model().to(device)

Target model to get $\color{orange}{Q(s'; \theta_i^{-})}$.

        self.target_model = Model().to(device)

Create workers.

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

Initialize tensors for observations.

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
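
The workers communicate with the trainer over a pipe: the trainer sends ("reset", None), ("step", action), or ("close", None), and the worker replies with an observation or a (next_obs, reward, done, info) tuple. The real Worker and Game are in labml_nn.rl.game and also handle frame preprocessing (stacking 4 grayscale 84x84 frames) and episode book-keeping; the sketch below only illustrates the assumed message protocol, with a hypothetical raw Gym environment.

import multiprocessing

import gym

def _worker_process(remote, seed: int):
    # Illustrative only: a child process that speaks the ("reset"/"step"/"close")
    # protocol the trainer uses above, without the preprocessing the real Game does.
    env = gym.make('BreakoutNoFrameskip-v4')
    env.seed(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "reset":
            remote.send(env.reset())
        elif cmd == "step":
            remote.send(env.step(data))
        elif cmd == "close":
            remote.close()
            break

class SketchWorker:
    def __init__(self, seed: int):
        # `child` is the trainer-side end of the duplex pipe
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=_worker_process, args=(parent, seed))
        self.process.start()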

Loss function.

        self.loss_func = QFuncLoss(0.99)

Optimizer.

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

When sampling actions we use an $\epsilon$-greedy strategy, where we take the greedy action with probability $1 - \epsilon$ and a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient.

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):

Sampling doesn't need gradients.

        with torch.no_grad():

Sample the action with the highest Q-value. This is the greedy action.

            greedy_action = torch.argmax(q_value, dim=-1)

Uniformly sample an action.

            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

Whether to choose the greedy action or the random action.

            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

Pick the action based on is_choose_rand.

            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):

This doesn't need gradients.

        with torch.no_grad():

Sample worker_steps.

            for t in range(self.worker_steps):

Get Q-values for the current observation.

                q_value = self.model(obs_to_torch(self.obs))

Sample actions.

                actions = self._sample_action(q_value, exploration_coefficient)

Run sampled actions on each worker.

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

Collect information from each worker.

                for w, worker in enumerate(self.workers):

Get results after executing the actions.

                    next_obs, reward, done, info = worker.child.recv()

Add the transition to the replay buffer.

                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

Update episode information: episode info is available if an episode finished; it includes the total reward and the length of the episode. Look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Update the current observation.

                    self.obs[w] = next_obs

    def train(self, beta: float):
        for _ in range(self.train_epochs):

Sample from the prioritized replay buffer.

            samples = self.replay_buffer.sample(self.mini_batch_size, beta)

Get the predicted Q-values.

            q_value = self.model(obs_to_torch(samples['obs']))

Get the Q-values of the next state for Double Q-learning. Gradients shouldn't propagate for these.

            with torch.no_grad():

Get $\color{cyan}{Q(s'; \theta_i)}$.

                double_q_value = self.model(obs_to_torch(samples['next_obs']))

Get $\color{orange}{Q(s'; \theta_i^{-})}$.

                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

Compute the Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.

            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
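
For reference, under the standard Double DQN formulation (the exact definition used here lives in QFuncLoss), the next action is chosen with the online network and evaluated with the target network. For a sampled transition $(s, a, r, s', d)$:

$$a^\star = \arg\max_{a'} \color{cyan}{Q(s', a'; \theta_i)}, \qquad \delta = \color{cyan}{Q(s, a; \theta_i)} - \Big(r + \gamma\,(1 - d)\,\color{orange}{Q(s', a^\star; \theta_i^{-})}\Big)$$

$$\mathcal{L}(\theta_i) = \frac{1}{N} \sum_j w_j\, \mathcal{H}(\delta_j)$$

where $\mathcal{H}$ is typically the Huber loss, $w_j$ are the importance-sampling weights from the replay buffer, and $\gamma$ is the discount factor (the 0.99 passed to QFuncLoss above).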

Calculate priorities for the replay buffer, $p_i = |\delta_i| + \epsilon$.

            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6

Update replay buffer priorities.

            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

Zero out the previously calculated gradients.

            self.optimizer.zero_grad()

Calculate gradients.

            loss.backward()

Clip gradients.

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients.

            self.optimizer.step()

    def run_training_loop(self):

Last 100 episode information.

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

Copy to the target network initially.

        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):

$\epsilon$, the exploration fraction.

            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)

$\beta$ for prioritized replay.

            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

Sample with the current policy.

            self.sample(exploration)

Start training after the buffer is full.

            if self.replay_buffer.is_full():

Train the model.

                self.train(beta)

Periodically update the target network.

                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically.

            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))

def main():

Create the experiment.

    experiment.create(name='dqn')

Initialize the trainer.

    m = Trainer()

Run and monitor the experiment.

    with experiment.start():
        m.run_training_loop()

Stop the workers.

    m.destroy()

if __name__ == "__main__":
    main()
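
Once trained, the same helpers can be reused to watch the greedy policy. This is only a rough sketch under the assumptions that Worker takes a seed as its single constructor argument (as in Trainer.__init__) and that Game auto-resets finished episodes (as the sampling loop above implies); play_greedy is a hypothetical helper, not part of the experiment.

def play_greedy(model: Model, steps: int = 1_000):
    # Illustrative only: drive a single worker with argmax actions.
    worker = Worker(99)
    worker.child.send(("reset", None))
    obs = worker.child.recv()
    for _ in range(steps):
        with torch.no_grad():
            q_value = model(obs_to_torch(obs[None, :]))
        action = int(torch.argmax(q_value, dim=-1).item())
        worker.child.send(("step", action))
        obs, reward, done, info = worker.child.recv()
    worker.child.send(("close", None))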