This experiment trains a Deep Q-Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

# Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# Scale observations from `[0, 255]` to `[0, 1]`
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # number of workers
        self.n_workers = n_workers
        # steps sampled on each update
        self.worker_steps = worker_steps
        # number of training iterations
        self.train_epochs = epochs
        # number of updates
        self.updates = updates
        # size of the mini batch for training
        self.mini_batch_size = mini_batch_size
        # update the target network every 250 updates
        self.update_target_model = update_target_model
        # learning rate
        self.learning_rate = learning_rate

        # exploration coefficient $\epsilon$ as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
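        # For instance (assuming `Piecewise` interpolates linearly between the given
        # (update, value) points and returns `outside_value` beyond the last one),
        # $\epsilon$ starts at 1.0, decays to 0.1 by update 25,000, then to 0.01 by
        # update `updates / 2`, and stays at 0.01 afterwards; halfway through the
        # first segment (update 12,500) it is roughly $\tfrac{1.0 + 0.1}{2} = 0.55$.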
        # $\beta$ for the replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

        # Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
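        # In standard prioritized experience replay (which this buffer is assumed to
        # follow), a transition with priority $p_i$ is sampled with probability
        # $P(i) = p_i^\alpha / \sum_k p_k^\alpha$ with $\alpha = 0.6$, and its loss is
        # scaled by an importance-sampling weight $w_i = (N \cdot P(i))^{-\beta}$,
        # where $\beta$ is annealed from 0.4 to 1 by the schedule above.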
        # Model for sampling and training
        self.model = Model().to(device)
        # target model to get $Q(s'; \theta^{-})$
        self.target_model = Model().to(device)

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
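        # Each observation is presumably the standard DQN Atari preprocessing:
        # a stack of the last 4 grayscale frames, each resized to 84x84 and stored
        # as `uint8` in [0, 255]; `obs_to_torch` converts it to floats in [0, 1].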
        # reset the workers
        for worker in self.workers:
            worker.child.send(("reset", None))

        # get the initial observations
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    # When sampling actions we use an $\epsilon$-greedy strategy: we take the greedy
    # action with probability $1 - \epsilon$ and a random action with probability
    # $\epsilon$. We refer to $\epsilon$ as `exploration_coefficient`.
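    #
    # A worked example (hypothetical numbers, not part of the training code): with
    # `exploration_coefficient = 0.1` and, say, 4 possible actions, the random branch
    # below is taken 10% of the time and picks uniformly among all 4 actions, so a
    # specific non-greedy action is chosen with probability
    # $0.1 \times \tfrac{1}{4} = 0.025$, and the greedy action is chosen with
    # probability $0.9 + 0.1 \times \tfrac{1}{4} = 0.925$.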
    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with the highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
            # Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
            # Whether to choose the random action or the greedy action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
            # Pick the action based on `is_choose_rand`
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        # This doesn't need gradients
        with torch.no_grad():
            # Sample `worker_steps`
            for t in range(self.worker_steps):
                # Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)

                # Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # Collect information from each worker
                for w, worker in enumerate(self.workers):
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # Add the transition to the replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # Update episode information. Collect episode info, which is available
                    # when an episode finishes; this includes the total reward and the
                    # length of the episode - look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

                    # update the current observation
                    self.obs[w] = next_obs

    def train(self, beta: float):
        for _ in range(self.train_epochs):
            # Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # Get the predicted Q-values
            q_value = self.model(obs_to_torch(samples['obs']))

            # Get the Q-values of the next state for Double Q-learning.
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Get $Q(s'; \theta)$ from the online model
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # Get $Q(s'; \theta^{-})$ from the target model
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

            # Compute the Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
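            # A rough sketch of what is computed here, assuming the standard Double
            # Q-learning target (see `QFuncLoss` for the exact details): the next-state
            # action is chosen with the online network and evaluated with the target
            # network,
            #     $a^* = \arg\max_a Q(s', a; \theta)$
            #     $y = r + \gamma (1 - \text{done}) \, Q(s', a^*; \theta^{-})$
            # with $\gamma = 0.99$; the TD error is $\delta = y - Q(s, a; \theta)$ and
            # the loss is a regression loss on $\delta$ weighted by `samples['weights']`.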
            # Calculate priorities for the replay buffer
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # Update the replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # Set the learning rate
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            # Update parameters based on gradients
            self.optimizer.step()

    def run_training_loop(self):
        # Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        # Copy to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, the exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # Sample with the current policy
            self.sample(exploration)

            # Start training after the buffer is full
            if self.replay_buffer.is_full():
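                # With the defaults in `main()` (8 workers x 4 steps), each update adds
                # 32 transitions, so the $2^{14} = 16{,}384$-entry buffer fills after
                # roughly 512 updates and training only begins then.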
                # Train the model
                self.train(beta)

            # Periodically update the target network
            if update % self.update_target_model == 0:
                self.target_model.load_state_dict(self.model.state_dict())
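                # With `update_target_model = 250` and 32 transitions sampled per
                # update (the defaults in `main()`), the target network is refreshed
                # about once every 8,000 environment steps.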
            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():
    # Create the experiment
    experiment.create(name='dqn')

    # Configurations
    configs = {
        # Number of updates
        'updates': 1_000_000,
        # Number of epochs to train the model with sampled data
        'epochs': 8,
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 4,
        # Mini batch size
        'mini_batch_size': 32,
        # Target model updating interval
        'update_target_model': 250,
        # Learning rate
        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }
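    # With these values, each update collects `n_workers * worker_steps = 8 * 4 = 32`
    # new transitions, and (once the buffer is full) runs 8 training epochs of one
    # 32-sample mini-batch each per update. `learning_rate` is a dynamic
    # hyper-parameter: it starts at 1e-4 and, assuming the usual behaviour of
    # `FloatDynamicHyperParam`, can be adjusted within (0, 1e-3) while the
    # experiment is running.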
    # Configurations
    experiment.configs(configs)

    # Initialize the trainer
    m = Trainer(**configs)

    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()

    # Stop the workers
    m.destroy()


if __name__ == "__main__":
    main()