This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_nn.helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

The Trainer class implements the sampling and training loop.

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):

Number of workers

        self.n_workers = n_workers

Number of steps to sample from each worker on each update

        self.worker_steps = worker_steps

Number of training iterations to run on the sampled data for each update

        self.train_epochs = epochs

Number of updates

        self.updates = updates

Mini-batch size for training

        self.mini_batch_size = mini_batch_size

Update the target network every update_target_model updates

        self.update_target_model = update_target_model

Learning rate

        self.learning_rate = learning_rate

ε, the exploration coefficient, as a function of updates

        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

β for the prioritized replay buffer, as a function of updates

        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
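
Both schedules map the update index to a coefficient by interpolating between the listed (update, value) knots and returning outside_value beyond the last knot. Here is a minimal sketch of that behaviour, assuming linear interpolation between the knots (the actual helper is labml_nn.helpers.schedule.Piecewise and may differ in detail):

from typing import List, Tuple

def piecewise(knots: List[Tuple[float, float]], outside_value: float, x: float) -> float:
    # Linearly interpolate between consecutive knots; fall back to outside_value
    for (x0, y0), (x1, y1) in zip(knots[:-1], knots[1:]):
        if x0 <= x < x1:
            t = (x - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
    return outside_value

# Exploration coefficient with updates = 1_000_000 (knots: 0 -> 1.0, 25k -> 0.1, 500k -> 0.01)
knots = [(0, 1.0), (25_000, 0.1), (500_000, 0.01)]
print(piecewise(knots, 0.01, 0))        # 1.0
print(piecewise(knots, 0.01, 12_500))   # 0.55, halfway between 1.0 and 0.1
print(piecewise(knots, 0.01, 600_000))  # 0.01, past the last knot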

Replay buffer with α = 0.6. The capacity of the replay buffer must be a power of 2.

        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

Model for sampling and training

        self.model = Model().to(device)

Target model, used to compute the target Q-values Q(s'; θ⁻)

        self.target_model = Model().to(device)

Create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

Initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

Reset the workers

        for worker in self.workers:
            worker.child.send(("reset", None))

Get the initial observations

        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

Loss function

        self.loss_func = QFuncLoss(0.99)

Optimizer

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

ε-greedy sampling: when sampling actions we use an ε-greedy strategy, taking the greedy action with probability 1 - ε and a random action with probability ε. We refer to ε as exploration_coefficient.

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):

Sampling doesn't need gradients

        with torch.no_grad():

Take the action with the highest Q-value; this is the greedy action.

            greedy_action = torch.argmax(q_value, dim=-1)

Uniformly sample a random action

            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

Whether to choose the greedy action or the random action

            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

Pick the action based on is_choose_rand

            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
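
To make the rule concrete, here is a small standalone illustration with hypothetical Q-values, mirroring the steps above; with exploration_coefficient = 0 the sampled actions are exactly the greedy ones:

import torch

# Q-values for 4 hypothetical states with 3 actions each
q = torch.tensor([[0.1, 0.9, 0.0],
                  [0.7, 0.2, 0.1],
                  [0.0, 0.0, 1.0],
                  [0.3, 0.4, 0.5]])

greedy = torch.argmax(q, dim=-1)                      # tensor([1, 0, 2, 2])
rand_action = torch.randint(q.shape[-1], greedy.shape)  # uniformly random actions
choose_rand = torch.rand(greedy.shape) < 0.0          # ε = 0: never pick the random action
print(torch.where(choose_rand, rand_action, greedy))  # tensor([1, 0, 2, 2])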

Sample data with the current policy

    def sample(self, exploration_coefficient: float):

This doesn't need gradients

        with torch.no_grad():

Sample worker_steps steps

            for t in range(self.worker_steps):

Get Q-values for the current observations

                q_value = self.model(obs_to_torch(self.obs))

Sample actions

                actions = self._sample_action(q_value, exploration_coefficient)

Run the sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

Collect information from each worker

                for w, worker in enumerate(self.workers):

Get the results after executing the actions

                    next_obs, reward, done, info = worker.child.recv()

Add the transition to the replay buffer

                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

Update episode information. Episode info is available when an episode finishes; it includes the total reward and the length of the episode. Look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Update the current observation

                    self.obs[w] = next_obs
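
The trainer talks to each worker over a pipe using a small command protocol: ("reset", None) returns an initial observation, ("step", action) returns a (next_obs, reward, done, info) tuple, and ("close", None) shuts the worker down. Below is a minimal sketch of a worker process following that protocol; it is illustrative only and uses a dummy environment, while the real Worker and Game classes live in labml_nn.rl.game:

import multiprocessing
import multiprocessing.connection
import numpy as np

class DummyGame:
    # Stand-in for the real Game wrapper (frame stacking, episode info, ...)
    def __init__(self, seed: int):
        self.rng = np.random.default_rng(seed)
    def reset(self):
        return np.zeros((4, 84, 84), dtype=np.uint8)
    def step(self, action):
        next_obs = self.rng.integers(0, 256, (4, 84, 84), dtype=np.uint8)
        return next_obs, 0.0, False, None

def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    game = DummyGame(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))   # (next_obs, reward, done, info)
        elif cmd == "reset":
            remote.send(game.reset())      # initial observation
        elif cmd == "close":
            remote.close()
            break

class Worker:
    def __init__(self, seed: int):
        # The trainer talks to `child`; the other end goes to the process
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()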

Train the model

    def train(self, beta: float):
        for _ in range(self.train_epochs):

Sample from the prioritized replay buffer

            samples = self.replay_buffer.sample(self.mini_batch_size, beta)

Get the predicted Q-values

            q_value = self.model(obs_to_torch(samples['obs']))

Get the Q-values of the next state for Double Q-learning. Gradients shouldn't propagate for these.

            with torch.no_grad():

Get Q(s'; θ) from the online model

                double_q_value = self.model(obs_to_torch(samples['next_obs']))

Get Q(s'; θ⁻) from the target model

                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

Compute the Temporal Difference (TD) errors, δ, and the loss, L(θ).

            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

Calculate new priorities for the replay buffer, |δ| plus a small constant

            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6

Update the replay buffer priorities

            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

Set the learning rate

            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

            self.optimizer.zero_grad()

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

            self.optimizer.step()
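
The arguments passed to QFuncLoss above correspond to the standard Double DQN target combined with prioritized-replay importance-sampling weights. Here is a rough sketch of that computation; it is illustrative only, and the exact formulation (error clipping, etc.) is defined by QFuncLoss in labml_nn.rl.dqn:

import torch

def double_dqn_loss_sketch(q_value, action, double_q_value, target_q_value,
                           done, reward, weights, gamma=0.99):
    # Q(s, a; θ) for the actions that were actually taken
    q_sampled = q_value.gather(-1, action.long().unsqueeze(-1)).squeeze(-1)
    # Double Q-learning: choose a' with the online network, evaluate it with the target network
    best_next_action = torch.argmax(double_q_value, dim=-1)
    q_next = target_q_value.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
    # TD target r + γ Q(s', a'; θ⁻), with no bootstrapping on terminal transitions
    q_target = reward + gamma * q_next * (1 - done)
    # TD error δ and the importance-sampling weighted loss
    td_error = q_sampled - q_target
    loss = (weights * td_error.pow(2)).mean()  # squared error for simplicity; the real loss shape may differ
    return td_error, loss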

Run the training loop

    def run_training_loop(self):

Keep information about the last 100 episodes

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

Copy the weights to the target network initially

        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):

ε, the exploration fraction

            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)

β for prioritized replay

            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

Sample with the current policy

            self.sample(exploration)

Start training after the buffer is full

            if self.replay_buffer.is_full():

Train the model

                self.train(beta)

Periodically update the target network

                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

Save tracked indicators

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy: stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))

def main():

Create the experiment

    experiment.create(name='dqn')

Configurations

    configs = {

Number of updates

        'updates': 1_000_000,

Number of epochs to train the model with the sampled data

        'epochs': 8,

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 4,

Mini-batch size

        'mini_batch_size': 32,

Target model updating interval

        'update_target_model': 250,

Learning rate

        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }
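
Note that learning_rate is a FloatDynamicHyperParam: the trainer reads its current value by calling it (pg['lr'] = self.learning_rate() in train()), so it can be adjusted while the run is live. A tiny illustration of that calling convention, assuming the default value is returned until it is changed externally (e.g. from the labml monitoring app):

from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam

# Default value 1e-4, allowed range (0, 1e-3)
lr = FloatDynamicHyperParam(1e-4, (0, 1e-3))

# Calling the hyperparameter returns its current float value
print(lr())  # 0.0001 until it is changed externally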

Set the configurations on the experiment

    experiment.configs(configs)

Initialize the trainer

    m = Trainer(**configs)

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()


if __name__ == "__main__":
    main()