This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game from OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
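The Worker class imported below from labml_nn.rl.game is not documented on this page. As a rough, hypothetical sketch of the protocol the trainer relies on, each worker runs one game instance in a child process and answers ("reset", None), ("step", action) and ("close", None) messages over a pipe; DummyGame here is a stand-in for the actual Atari Breakout wrapper, not the real implementation.

import multiprocessing
import multiprocessing.connection

import numpy as np


class DummyGame:
    # Stand-in for the Atari wrapper: 4 stacked 84x84 frames, 4 actions
    def reset(self):
        return np.zeros((4, 84, 84), dtype=np.uint8)

    def step(self, action):
        obs = np.random.randint(0, 256, (4, 84, 84), dtype=np.uint8)
        reward, done, info = 0., False, None
        return obs, reward, done, info


def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    # Create the game and serve requests until a "close" message arrives
    game = DummyGame()
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break


class Worker:
    # Each worker keeps one end of the pipe (`child`) and a child process running the game
    def __init__(self, seed: int):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()

The trainer only relies on worker.child.send(...) and worker.child.recv(), which is all this sketch implements.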
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

# Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Scale observations from [0, 255] to [0, 1]
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # number of workers
        self.n_workers = n_workers
        # steps sampled on each update
        self.worker_steps = worker_steps
        # number of training iterations
        self.train_epochs = epochs
        # number of updates
        self.updates = updates
        # size of mini batch for training
        self.mini_batch_size = mini_batch_size
        # update target network every 250 updates
        self.update_target_model = update_target_model
        # learning rate
        self.learning_rate = learning_rate
        # exploration as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
        # $\beta$ for replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
        # Replay buffer with $\alpha = 0.6$. Capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
        # Model for sampling and training
        self.model = Model().to(device)
        # target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)
        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        # reset the workers
        for worker in self.workers:
            worker.child.send(("reset", None))
        # get the initial observations
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
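The two Piecewise schedules above come from labml_helpers; as far as I can tell they linearly interpolate between the given (update, value) endpoints and return outside_value beyond the last one. A minimal stand-in with the same call signature, for illustration only (PiecewiseSketch is hypothetical, not the library class):

class PiecewiseSketch:
    # Hypothetical re-implementation of a piecewise-linear schedule
    def __init__(self, endpoints, outside_value):
        self.endpoints = endpoints
        self.outside_value = outside_value

    def __call__(self, x):
        for (x0, y0), (x1, y1) in zip(self.endpoints[:-1], self.endpoints[1:]):
            if x0 <= x < x1:
                # Linearly interpolate between the two surrounding endpoints
                return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
        return self.outside_value


exploration = PiecewiseSketch([(0, 1.0), (25_000, 0.1), (500_000, 0.01)], outside_value=0.01)
print(exploration(0), exploration(12_500), exploration(25_000))  # ~1.0, 0.55, 0.1

With updates = 1,000,000 the exploration coefficient therefore decays from 1.0 to 0.1 over the first 25,000 updates and then to 0.01 by update 500,000, while $\beta$ anneals from 0.4 to 1 over the whole run.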
When sampling actions we use an $\epsilon$-greedy strategy: we take a greedy action with probability $1 - \epsilon$ and a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient.
    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
            # Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
            # Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
            # Pick the action based on is_choose_rand
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        # This doesn't need gradients
        with torch.no_grad():
            # Sample worker_steps
            for t in range(self.worker_steps):
                # Get Q_values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)
                # Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))
                # Collect information from each worker
                for w, worker in enumerate(self.workers):
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()
                    # Add transition to replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
                    # Update episode information:
                    # collect episode info, which is available if an episode finished;
                    # this includes total reward and length of the episode -
                    # look at Game to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
                    # Update current observation
                    self.obs[w] = next_obs
    def train(self, beta: float):
        for _ in range(self.train_epochs):
            # Sample from priority replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # Get the predicted Q-value
            q_value = self.model(obs_to_torch(samples['obs']))
            # Get the Q-values of the next state for Double Q-learning.
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
            # Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
            # Set learning rate
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            # Update parameters based on gradients
            self.optimizer.step()
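Written out, the loss computed by self.loss_func (QFuncLoss is documented further below) uses the online network to pick the next action and the target network to evaluate it, with $d$ the episode-termination flag and $w$ the importance-sampling weights from the replay buffer:

$a^* = \arg\max_{a'} Q(s', a'; \theta_i)$

$y = r + \gamma \, (1 - d) \, Q(s', a^*; \theta_i^{-})$

$\delta = Q(s, a; \theta_i) - y$

$\mathcal{L}(\theta_i) = \mathbb{E}\big[w \cdot \text{Huber}\big(Q(s, a; \theta_i),\ y\big)\big]$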
    def run_training_loop(self):
        # Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)
        # Copy to target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)
            # Sample with current policy
            self.sample(exploration)
            # Start training after the buffer is full
            if self.replay_buffer.is_full():
                # Train the model
                self.train(beta)
            # Periodically update target network
            if update % self.update_target_model == 0:
                self.target_model.load_state_dict(self.model.state_dict())
            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        # Stop the child processes
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
    # Create the experiment
    experiment.create(name='dqn')
    # Configurations
    configs = {
        # Number of updates
        'updates': 1_000_000,
        # Number of epochs to train the model with sampled data.
        'epochs': 8,
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 4,
        # Mini batch size
        'mini_batch_size': 32,
        # Target model updating interval
        'update_target_model': 250,
        # Learning rate.
        'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
    }
    # Configurations
    experiment.configs(configs)
    # Initialize the trainer
    m = Trainer(**configs)
    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


if __name__ == "__main__":
    main()
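FloatDynamicHyperParam appears to be labml's dynamic hyper-parameter: it is constructed with a default value and an allowed range, and calling it returns the current value, which can be adjusted from the labml app while the run is in progress. That is why train() re-reads it on every step with pg['lr'] = self.learning_rate(). A minimal illustration (the printed value is just the default until it is changed externally):

lr = FloatDynamicHyperParam(2.5e-4, (0, 1e-3))
print(lr())  # 0.00025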
The Q-function loss used by the trainer, QFuncLoss:

from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer


class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma
        self.huber_loss = nn.SmoothL1Loss(reduction='none')

    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # $Q(s,a;\theta_i)$
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

        # Gradients shouldn't propagate through the target values
        with torch.no_grad():
            # Pick the best next action with the online network, $\arg\max_{a'} Q(s';\theta_i)$
            best_next_action = torch.argmax(double_q, -1)
            # Evaluate it with the target network, $Q(s';\theta_i^{-})$
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
            # Target value $r + \gamma Q(s', a^*; \theta_i^{-}) (1 - done)$
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)

            # Temporal difference error $\delta$ is used to weigh samples in replay buffer
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

        # Huber loss for each sample
        losses = self.huber_loss(q_sampled_action, q_update)
        # Get weighted means
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
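A quick shape check, assuming a mini batch of 32 and Breakout's 4 actions (illustrative only; the tracker.add calls log into the current labml run, so outside an experiment they may only warn):

loss_fn = QFuncLoss(gamma=0.99)
q = torch.randn(32, 4)             # Q(s, .) from the online network
double_q = torch.randn(32, 4)      # Q(s', .) from the online network
target_q = torch.randn(32, 4)      # Q(s', .) from the target network
action = torch.randint(0, 4, (32,))
done = torch.zeros(32)
reward = torch.randn(32)
weights = torch.ones(32)

td_error, loss = loss_fn(q, action, double_q, target_q, done, reward, weights)
print(td_error.shape, loss.shape)  # torch.Size([32]) torch.Size([])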
The dueling network model that computes the $Q$ values:

import torch
from torch import nn

from labml_helpers.module import Module


class Model(Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution takes the 84x84 frames down to 20x20
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            # The second convolution produces 9x9 frames
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            # The third convolution produces 7x7 frames
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # A fully connected layer maps the flattened convolution output to 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()
        # This head gives the state value $V$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # This head gives the action value $A$
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

    def forward(self, obs: torch.Tensor):
        # Convolution
        h = self.conv(obs)
        # Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))
        # Linear layer
        h = self.activation(self.lin(h))
        # $A$
        action_value = self.action_value(h)
        # $V$
        state_value = self.state_value(h)
        # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        q = state_value + action_score_centered

        return q
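A quick sanity check of the shapes, assuming the 4-frame 84x84 observations produced by the workers: the convolutions reduce 84 to 20, 9 and then 7, which is where the 7 * 7 * 64 above comes from, and the output has one Q-value per action.

model = Model()
q = model(torch.zeros(2, 4, 84, 84))
print(q.shape)  # torch.Size([2, 4])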
The prioritized experience replay buffer:

import random

import numpy as np


# Priorities are kept in binary segment trees so that their sum over any range can be
# computed efficiently; we use the same structure to compute the minimum.
class ReplayBuffer:
    def __init__(self, capacity, alpha):
        # We use a power of $2$ for capacity because it simplifies the code and debugging
        self.capacity = capacity
        # $\alpha$
        self.alpha = alpha
        # Maintain segment binary trees to take sum and find minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
        # Current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.
        # Arrays for buffer
        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=bool)
        }
        # Next available slot
        self.next_idx = 0
        # Size of the buffer
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        # Get next available slot
        idx = self.next_idx
        # Store in the queue
        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done
        # Increment next available slot
        self.next_idx = (idx + 1) % self.capacity
        # Calculate the size
        self.size = min(self.capacity, self.size + 1)
        # $p_i^\alpha$, new samples get max_priority
        priority_alpha = self.max_priority ** self.alpha
        # Update the two segment trees for sum and minimum
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        # Leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha
        # Update the tree by traversing along the ancestors
        while idx >= 2:
            # Get the index of the parent node
            idx //= 2
            # Value of the parent node is the minimum of its two children
            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        # Leaf of the binary tree
        idx += self.capacity
        # Set the priority at the leaf
        self.priority_sum[idx] = priority
        # Update the tree by traversing along the ancestors
        while idx >= 2:
            # Get the index of the parent node
            idx //= 2
            # Value of the parent node is the sum of its two children
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
        # The root node keeps the sum of all values
        return self.priority_sum[1]

    def _min(self):
        # The root node keeps the minimum of all values
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        # Start from the root
        idx = 1
        while idx < self.capacity:
            # If the sum of the left branch is higher than the required sum
            if self.priority_sum[idx * 2] > prefix_sum:
                # Go to the left branch of the tree
                idx = 2 * idx
            else:
                # Otherwise go to the right branch and subtract the sum of the left branch
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1
        # We are at a leaf node; subtract the capacity to get the index of the transition
        return idx - self.capacity
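A hypothetical walk-through of the prefix-sum search with a tiny buffer of capacity 4, setting the leaf priorities $p_i^\alpha = [3, 1, 5, 1]$ directly through the private helpers, purely for illustration:

buffer = ReplayBuffer(4, 0.6)
for i, p in enumerate([3., 1., 5., 1.]):
    buffer._set_priority_min(i, p)
    buffer._set_priority_sum(i, p)

# Leaves live at tree indexes 4..7; the internal nodes at indexes 1..3 hold [10, 4, 6]
print(buffer._sum(), buffer._min())     # 10.0 1.0
# For a prefix sum of 6.5: left subtree sum 4 <= 6.5, go right with 2.5 remaining;
# left child sum 5 > 2.5, go left; we reach leaf 6, i.e. transition 2
print(buffer.find_prefix_sum_idx(6.5))  # 2

Any prefix sum drawn uniformly from $[4, 9)$ maps to transition 2, so it is sampled with probability $5/10 = p_2^\alpha / \sum_k p_k^\alpha$, which is exactly what sample() below relies on.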
    def sample(self, batch_size, beta):
        # Initialize samples
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }
        # Get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx
        # $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
        prob_min = self._min() / self._sum()
        # $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
            # $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
            prob = self.priority_sum[idx + self.capacity] / self._sum()
            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
            weight = (prob * self.size) ** (-beta)
            # Normalize by $\max_i w_i$
            samples['weights'][i] = weight / max_weight
        # Get samples data
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):
            # Set current max priority
            self.max_priority = max(self.max_priority, priority)
            # Calculate $p_i^\alpha$
            priority_alpha = priority ** self.alpha
            # Update the trees
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        # Whether the buffer is full
        return self.capacity == self.size
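To continue the toy numbers above with $N = 4$ and $\beta = 0.4$ (the initial value of the schedule): the frequently-sampled transition has $P(2) = 5/10 = 0.5$ and raw weight $w_2 = (4 \cdot 0.5)^{-0.4} \approx 0.76$, while the rarest transitions have $P(i) = 1/10$ and $w_i = (4 \cdot 0.1)^{-0.4} \approx 1.44$. After dividing by $\max_i w_i$ these become roughly $0.53$ and $1.0$, so transitions that are replayed more often contribute proportionally less to the loss, which is what the $w_i$ correction in sample() is for.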
The docs for a second experiment get the same treatment: its main() previously passed updates, epochs, n_workers, worker_steps, batches, value_loss_coef, entropy_bonus_coef, clip_range and learning_rate to its Trainer one by one, and now simply unpacks the configuration dictionary:

    # Initialize the trainer
    m = Trainer(**configs)
    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


if __name__ == "__main__":
    main()