This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
Scale observations from [0, 255] to [0, 1]
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
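As a quick illustration of the scaling, a full-intensity frame maps to 1.0 (a minimal check with a made-up observation batch of the shape used below):

dummy = np.full((8, 4, 84, 84), 255, dtype=np.uint8)
obs_to_torch(dummy).max()  # tensor(1.) on the selected device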
class Trainer:
    def __init__(self):
number of workers
        self.n_workers = 8
steps sampled on each update
        self.worker_steps = 4
number of training iterations
        self.train_epochs = 8
number of updates
        self.updates = 1_000_000
size of mini batch for training
        self.mini_batch_size = 32
exploration as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
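For intuition, this is roughly how the schedule evaluates, assuming Piecewise interpolates linearly between the given (update, value) points and returns outside_value beyond the last point (an illustrative sketch, not part of the trainer):

schedule = Piecewise([(0, 1.0), (25_000, 0.1), (500_000, 0.01)], outside_value=0.01)
schedule(0)        # 1.0 at the start
schedule(12_500)   # about 0.55, halfway between 1.0 and 0.1
schedule(25_000)   # 0.1
schedule(600_000)  # 0.01, past the last point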
update target network every 250 updates
        self.update_target_model = 250
$\beta$ for replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
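Assuming the replay buffer follows the standard prioritized replay recipe, each sampled transition gets an importance-sampling weight $w_i = (N \cdot P(i))^{-\beta}$, normalized by the largest weight, so annealing $\beta$ from 0.4 to 1 over training moves from priority-favoring updates to fully bias-corrected ones by the end. The weights appear below as samples['weights'].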
Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2, since it is implemented with binary segment trees.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
Model for sampling and training
        self.model = Model().to(device)
target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)
create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
initialize tensors for observations; each worker provides a stack of four 84×84 frames
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
loss function with discount factor $\gamma = 0.99$
        self.loss_func = QFuncLoss(0.99)
optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
When sampling actions we use an $\epsilon$-greedy strategy, where we take the greedy action with probability $1 - \epsilon$ and a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient.
    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
Sampling doesn’t need gradients
        with torch.no_grad():
Sample the action with highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
Pick the action based on is_choose_rand
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
    def sample(self, exploration_coefficient: float):
This doesn’t need gradients
        with torch.no_grad():
Sample worker_steps
            for t in range(self.worker_steps):
Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)
Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))
Collect information from each worker
                for w, worker in enumerate(self.workers):
Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()
Add transition to replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
Update episode information. Episode info is available when an episode finishes; it includes the total reward and the length of the episode. Look at Game to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
update current observation
                    self.obs[w] = next_obs
    def train(self, beta: float):
        for _ in range(self.train_epochs):
Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
Get the predicted Q-value
            q_value = self.model(obs_to_torch(samples['obs']))
Get the Q-values of the next state for Double Q-learning. Gradients shouldn’t propagate for these
            with torch.no_grad():
Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
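For reference, assuming QFuncLoss follows the standard Double DQN formulation with an importance-weighted Huber loss, the quantities above are
$$y_i = r_i + \gamma (1 - d_i)\, Q\big(s'_i, \operatorname{argmax}_{a'} Q(s'_i, a'; \theta_i);\, \theta_i^{-}\big), \qquad \delta_i = y_i - Q(s_i, a_i; \theta_i),$$
$$\mathcal{L}(\theta) = \frac{1}{N} \sum_i w_i \, \text{huber}(\delta_i),$$
where $d_i$ is the done flag, $w_i$ are the replay weights, and the action in the target is chosen by the online network $\theta_i$ but evaluated by the target network $\theta_i^{-}$.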
Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
Zero out the previously calculated gradients
            self.optimizer.zero_grad()
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update parameters based on gradients
            self.optimizer.step()
    def run_training_loop(self):
Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)
Copy to target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
$\epsilon$, exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
$\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)
Sample with current policy
            self.sample(exploration)
Start training after the buffer is full
            if self.replay_buffer.is_full():
Train the model
                self.train(beta)
Periodically update target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())
Save tracked indicators.
            tracker.save()
Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
Create the experiment
    experiment.create(name='dqn')
Initialize the trainer
    m = Trainer()
Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
Stop the workers
    m.destroy()
if __name__ == "__main__":
    main()