📚 annotations
@@ -1,9 +1,8 @@
"""
\(
   \def\hl1#1{{\color{orange}{#1}}}
   \def\blue#1{{\color{cyan}{#1}}}
   \def\green#1{{\color{yellowgreen}{#1}}}
\)
# DQN Experiment with Atari Breakout

This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym.
It runs the [game environments on multiple processes](../game.html) to sample efficiently.
"""

import numpy as np
@@ -16,6 +15,7 @@ from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

# Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
@@ -29,17 +29,10 @@ def obs_to_torch(obs: np.ndarray) -> torch.Tensor:

class Trainer:
    """
    ## <a name="main"></a>Main class
    This class runs the training loop.
    It initializes TensorFlow, handles logging and monitoring,
     and runs workers as multiple processes.
    ## Trainer
    """

    def __init__(self):
        """
        ### Initialize
        """

        # #### Configurations

        # number of workers
@@ -54,7 +47,7 @@ class Trainer:
        # size of mini batch for training
        self.mini_batch_size = 32

        # exploration as a function of time step
        # exploration as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
@@ -65,20 +58,21 @@ class Trainer:
        # update target network every 250 updates
        self.update_target_model = 250

        # $\beta$ for replay buffer as a function of time steps
        # $\beta$ for replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
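        # `Piecewise` is used here as a schedule over updates. A rough standalone sketch of
        # how such a piecewise-linear schedule can be evaluated (illustrative only; the
        # actual `Piecewise` helper may differ in details):
        #
        #     def piecewise(points, outside_value, x):
        #         for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        #             if x0 <= x < x1:
        #                 return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
        #         return outside_value
        #
        #     # e.g. under this assumption the $\beta$ schedule above rises linearly from
        #     # 0.4 to 1 over `self.updates` updates and stays at 1 afterwards.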

        # replay buffer
        # replay buffer with $\alpha = 0.6$
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
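        # With prioritized replay, a transition with priority $p_i$ is sampled with
        # probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$ and the loss is scaled by an
        # importance-sampling weight $w_i = (N \cdot P(i))^{-\beta}$, normalized by the
        # largest weight. A small illustrative calculation (not tied to the actual
        # `ReplayBuffer` internals):
        #
        #     priorities = np.array([1., 2., 4.])
        #     alpha, beta, n = 0.6, 0.4, len(priorities)
        #     probs = priorities ** alpha / np.sum(priorities ** alpha)
        #     weights = (n * probs) ** (-beta)
        #     weights /= weights.max()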

        # Model for sampling and training
        self.model = Model().to(device)
        # target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)

        # last observation for each worker
        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

@@ -89,6 +83,7 @@ class Trainer:
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
@@ -99,44 +94,48 @@ class Trainer:
        When sampling actions we use an $\epsilon$-greedy strategy, where we
        take a greedy action with probability $1 - \epsilon$ and
        take a random action with probability $\epsilon$.
        We refer to $\epsilon$ as *exploration*.
        We refer to $\epsilon$ as `exploration_coefficient`.
        """

        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with the highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
            # Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

            # Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

            # Pick the action based on `is_choose_rand`
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
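
        # A quick, illustrative sanity check of this $\epsilon$-greedy logic (standalone code
        # with made-up toy values, not part of the experiment): with a fixed Q-value tensor
        # and $\epsilon = 0.25$, about a quarter of the picks come from the uniform branch,
        # and of those roughly 3 in 4 differ from the greedy action.
        #
        #     q = torch.zeros(10_000, 4)
        #     q[:, 2] = 1.  # action 2 is always greedy
        #     greedy = torch.argmax(q, dim=-1)
        #     rand = torch.randint(q.shape[-1], greedy.shape)
        #     pick_rand = torch.rand(greedy.shape) < 0.25
        #     picked = torch.where(pick_rand, rand, greedy)
        #     print((picked != 2).float().mean())  # ≈ 0.25 * 3/4 ≈ 0.19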

    def sample(self, exploration_coefficient: float):
        """### Sample data"""

        # This doesn't need gradients
        with torch.no_grad():
            # sample `SAMPLE_STEPS`
            # Sample `worker_steps`
            for t in range(self.worker_steps):
                # sample actions
                # Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)

                # run sampled actions on each worker
                # Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # collect information from each worker
                # Collect information from each worker
                for w, worker in enumerate(self.workers):
                    # get results after executing the actions
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # add transition to replay buffer
                    # Add transition to replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # update episode information
                    # Collect episode info, which is available if an episode finished;
                    #  this includes total reward and length of the episode;
                    #  look at `Game` to see how it works.
                    # We also add a game frame to it for monitoring.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
@@ -145,16 +144,24 @@ class Trainer:
                    self.obs[w] = next_obs

    def train(self, beta: float):
        """
        ### Train the model
        """
        for _ in range(self.train_epochs):
            # sample from priority replay buffer
            # Sample from priority replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # train network
            # Get the predicted Q-value
            q_value = self.model(obs_to_torch(samples['obs']))

            # Get the Q-values of the next state for [Double Q-learning](index.html).
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
@@ -162,15 +169,18 @@ class Trainer:
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
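            # The loss combines these in the Double Q-learning style: the online network
            # picks the best next action and the target network evaluates it. Roughly (the
            # exact computation lives in [`QFuncLoss`](index.html); `reward`, `gamma` and
            # `done` below stand for the sampled rewards, the discount factor and the
            # terminal flags):
            #
            #     best_next_action = double_q_value.argmax(-1)
            #     q_next = target_q_value.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
            #     q_target = reward + gamma * q_next * (1 - done)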

            # $p_i = |\delta_i| + \epsilon$
            # Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # update replay buffer
            # Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # compute gradients
            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
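            # `clip_grad_norm_` rescales all the gradients in place so that their combined
            # $L_2$ norm is at most `max_norm`. Roughly, with `grads` standing for the
            # parameter gradients (a sketch, not the exact PyTorch implementation):
            #
            #     total_norm = sum(g.norm() ** 2 for g in grads) ** 0.5
            #     if total_norm > max_norm:
            #         for g in grads:
            #             g *= max_norm / total_norm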
            # Update parameters based on gradients
            self.optimizer.step()

    def run_training_loop(self):
@@ -178,33 +188,36 @@ class Trainer:
        ### Run training loop
        """

        # copy to target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        # last 100 episode information
        # Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        # Copy to target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for priority replay
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # sample with current policy
            # Sample with current policy
            self.sample(exploration)

            # Start training after the buffer is full
            if self.replay_buffer.is_full():
                # train the model
                # Train the model
                self.train(beta)

                # periodically update target network
                # Periodically update target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

@@ -217,10 +230,18 @@ class Trainer:
            worker.child.send(("close", None))


# ## Run it
if __name__ == "__main__":
def main():
    # Create the experiment
    experiment.create(name='dqn')
    # Initialize the trainer
    m = Trainer()
    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


# ## Run it
if __name__ == "__main__":
    main()
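
# To launch this from the command line, something like the following should work,
# assuming this file is `labml_nn/rl/dqn/experiment.py` in the repository (the diff itself
# doesn't show the file path):
#
#     python -m labml_nn.rl.dqn.experiment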