diff --git a/docs/rl/dqn/experiment.html b/docs/rl/dqn/experiment.html
index 10bf4e87..9730c6b5 100644
--- a/docs/rl/dqn/experiment.html
+++ b/docs/rl/dqn/experiment.html
@@ -70,17 +70,20 @@

DQN Experiment with Atari Breakout

This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
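The trainer talks to each game process over a pipe: it sends ("reset", None), ("step", action) and ("close", None) messages and reads the results back, as the Worker calls further down show. A minimal standalone sketch of that pattern, using a hypothetical toy worker rather than the real labml_nn.rl.game.Worker:

```python
import multiprocessing as mp

import numpy as np


def worker_process(remote, seed: int):
    """Toy stand-in for a game process: answers "reset", "step" and "close" messages."""
    rng = np.random.default_rng(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "reset":
            # Send back a fake stack of four 84x84 frames
            remote.send(rng.integers(0, 255, (4, 84, 84), dtype=np.uint8))
        elif cmd == "step":
            obs = rng.integers(0, 255, (4, 84, 84), dtype=np.uint8)
            remote.send((obs, 0.0, False, {}))  # next_obs, reward, done, info
        elif cmd == "close":
            remote.close()
            break


if __name__ == "__main__":
    parent, child = mp.Pipe()
    p = mp.Process(target=worker_process, args=(child, 47))
    p.start()
    parent.send(("reset", None))
    print(parent.recv().shape)  # (4, 84, 84)
    parent.send(("close", None))
    p.join()
```

Running several such processes in parallel is what lets the trainer batch observations from all workers in a single forward pass.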

+

Open In Colab +View Run

-
13import numpy as np
-14import torch
-15
-16from labml import tracker, experiment, logger, monit
-17from labml_helpers.schedule import Piecewise
-18from labml_nn.rl.dqn import QFuncLoss
-19from labml_nn.rl.dqn.model import Model
-20from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
-21from labml_nn.rl.game import Worker
+
16import numpy as np
+17import torch
+18
+19from labml import tracker, experiment, logger, monit
+20from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
+21from labml_helpers.schedule import Piecewise
+22from labml_nn.rl.dqn import QFuncLoss
+23from labml_nn.rl.dqn.model import Model
+24from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
+25from labml_nn.rl.game import Worker
@@ -91,10 +94,10 @@ It runs the game environments on multiple processes t

Select device

-
24if torch.cuda.is_available():
-25    device = torch.device("cuda:0")
-26else:
-27    device = torch.device("cpu")
+
28if torch.cuda.is_available():
+29    device = torch.device("cuda:0")
+30else:
+31    device = torch.device("cpu")
@@ -105,7 +108,7 @@ It runs the game environments on multiple processes t

Scale observations from [0, 255] to [0, 1]

-
30def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
+
34def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
@@ -116,7 +119,7 @@ It runs the game environments on multiple processes t
-
32    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
+
36    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
@@ -127,7 +130,7 @@ It runs the game environments on multiple processes t

Trainer

-
35class Trainer:
+
39class Trainer:
@@ -138,7 +141,12 @@ It runs the game environments on multiple processes t
-
40    def __init__(self):
+
44    def __init__(self, *,
+45                 updates: int, epochs: int,
+46                 n_workers: int, worker_steps: int, mini_batch_size: int,
+47                 update_target_model: int,
+48                 learning_rate: FloatDynamicHyperParam,
+49                 ):
@@ -146,10 +154,10 @@ It runs the game environments on multiple processes t -

Configurations

+

number of workers

-
+
51        self.n_workers = n_workers
@@ -157,10 +165,10 @@ It runs the game environments on multiple processes t -

number of workers

+

steps sampled on each update

-
44        self.n_workers = 8
+
53        self.worker_steps = worker_steps
@@ -168,10 +176,10 @@ It runs the game environments on multiple processes t -

steps sampled on each update

+

number of training iterations

-
46        self.worker_steps = 4
+
55        self.train_epochs = epochs
@@ -179,10 +187,10 @@ It runs the game environments on multiple processes t -

number of training iterations

+

number of updates

-
48        self.train_epochs = 8
+
58        self.updates = updates
@@ -190,10 +198,10 @@ It runs the game environments on multiple processes t -

number of updates

+

size of mini batch for training

-
51        self.updates = 1_000_000
+
60        self.mini_batch_size = mini_batch_size
@@ -201,10 +209,10 @@ It runs the game environments on multiple processes t -

size of mini batch for training

+

update target network every 250 updates

-
53        self.mini_batch_size = 32
+
63        self.update_target_model = update_target_model
@@ -212,15 +220,10 @@ It runs the game environments on multiple processes t -

exploration as a function of updates

+

learning rate

-
56        self.exploration_coefficient = Piecewise(
-57            [
-58                (0, 1.0),
-59                (25_000, 0.1),
-60                (self.updates / 2, 0.01)
-61            ], outside_value=0.01)
+
66        self.learning_rate = learning_rate
@@ -228,10 +231,15 @@ It runs the game environments on multiple processes t -

update target network every 250 updates

+

exploration as a function of updates

-
64        self.update_target_model = 250
+
69        self.exploration_coefficient = Piecewise(
+70            [
+71                (0, 1.0),
+72                (25_000, 0.1),
+73                (self.updates / 2, 0.01)
+74            ], outside_value=0.01)
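Piecewise from labml_helpers.schedule is called with the current update number; as used here it appears to interpolate linearly between the given breakpoints and return outside_value beyond the last one. A rough independent sketch of that behaviour (not labml's implementation):

```python
import numpy as np


def piecewise(x: float, points, outside_value: float) -> float:
    """Linearly interpolate between (x, y) breakpoints; return outside_value past the last one."""
    xs, ys = zip(*points)
    if x > xs[-1]:
        return outside_value
    return float(np.interp(x, xs, ys))


# Exploration coefficient schedule above, assuming updates = 1_000_000
schedule = [(0, 1.0), (25_000, 0.1), (500_000, 0.01)]
for update in (0, 12_500, 25_000, 500_000, 900_000):
    print(update, round(piecewise(update, schedule, outside_value=0.01), 3))
# 0 1.0, 12500 0.55, 25000 0.1, 500000 0.01, 900000 0.01
```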
@@ -242,11 +250,11 @@ It runs the game environments on multiple processes t

$\beta$ for replay buffer as a function of updates

-
67        self.prioritized_replay_beta = Piecewise(
-68            [
-69                (0, 0.4),
-70                (self.updates, 1)
-71            ], outside_value=1)
+
77        self.prioritized_replay_beta = Piecewise(
+78            [
+79                (0, 0.4),
+80                (self.updates, 1)
+81            ], outside_value=1)
@@ -257,7 +265,7 @@ It runs the game environments on multiple processes t

Replay buffer with $\alpha = 0.6$. Capacity of the replay buffer must be a power of 2.

-
74        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
+
84        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
@@ -268,7 +276,7 @@ It runs the game environments on multiple processes t

Model for sampling and training

-
77        self.model = Model().to(device)
+
87        self.model = Model().to(device)
@@ -279,7 +287,7 @@ It runs the game environments on multiple processes t

target model to get $\color{orange}Q(s’;\color{orange}{\theta_i^{-}})$

-
79        self.target_model = Model().to(device)
+
89        self.target_model = Model().to(device)
@@ -290,7 +298,7 @@ It runs the game environments on multiple processes t

create workers

-
82        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
+
92        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
@@ -301,11 +309,7 @@ It runs the game environments on multiple processes t

initialize tensors for observations

-
85        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
-86        for worker in self.workers:
-87            worker.child.send(("reset", None))
-88        for i, worker in enumerate(self.workers):
-89            self.obs[i] = worker.child.recv()
+
95        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
@@ -313,10 +317,11 @@ It runs the game environments on multiple processes t -

loss function

+

reset the workers

-
92        self.loss_func = QFuncLoss(0.99)
+
98        for worker in self.workers:
+99            worker.child.send(("reset", None))
@@ -324,17 +329,40 @@ It runs the game environments on multiple processes t -

optimizer

+

get the initial observations

-
94        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
+
102        for i, worker in enumerate(self.workers):
+103            self.obs[i] = worker.child.recv()
-
+
+

loss function

+
+
+
106        self.loss_func = QFuncLoss(0.99)
+
+
+
+
+ +

optimizer

+
+
+
108        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
+
+
+
+
+

$\epsilon$-greedy Sampling

When sampling actions we use an $\epsilon$-greedy strategy, where we take a greedy action with probability $1 - \epsilon$ and @@ -342,29 +370,7 @@ take a random action with probability $\epsilon$. We refer to $\epsilon$ as exploration_coefficient.

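As a self-contained illustration of the strategy (separate from the _sample_action method documented below), ε-greedy selection over a batch of Q-values looks like this:

```python
import torch


def epsilon_greedy(q_values: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Greedy action with probability 1 - epsilon, uniform random action with probability epsilon."""
    greedy_action = torch.argmax(q_values, dim=-1)
    random_action = torch.randint(q_values.shape[-1], greedy_action.shape)
    choose_random = torch.rand(greedy_action.shape) < epsilon
    return torch.where(choose_random, random_action, greedy_action)


q = torch.tensor([[0.1, 0.9, 0.3],
                  [0.5, 0.2, 0.4]])
print(epsilon_greedy(q, epsilon=0.1))  # usually tensor([1, 0])
```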
-
96    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
-
-
-
-
- -

Sampling doesn’t need gradients

-
-
-
106        with torch.no_grad():
-
-
-
-
- -

Sample the action with the highest Q-value. This is the greedy action.

-
-
-
108            greedy_action = torch.argmax(q_value, dim=-1)
+
110    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
@@ -372,10 +378,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Uniformly sample an action

+

Sampling doesn’t need gradients

-
110            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
+
120        with torch.no_grad():
@@ -383,10 +389,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Whether to choose the greedy action or the random action

+

Sample the action with the highest Q-value. This is the greedy action.

-
112            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
+
122            greedy_action = torch.argmax(q_value, dim=-1)
@@ -394,21 +400,21 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Pick the action based on is_choose_rand

+

Uniformly sample an action

-
114            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
+
124            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
-
+
-

Sample data

+

Whether to choose the greedy action or the random action

-
116    def sample(self, exploration_coefficient: float):
+
126            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
@@ -416,21 +422,21 @@ We refer to $\epsilon$ as exploration_coefficient.

-

This doesn’t need gradients

+

Pick the action based on is_choose_rand

-
120        with torch.no_grad():
+
128            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
-
+
-

Sample worker_steps

+

Sample data

-
122            for t in range(self.worker_steps):
+
130    def sample(self, exploration_coefficient: float):
@@ -438,10 +444,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Get Q_values for the current observation

+

This doesn’t need gradients

-
124                q_value = self.model(obs_to_torch(self.obs))
+
134        with torch.no_grad():
@@ -449,10 +455,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Sample actions

+

Sample worker_steps

-
126                actions = self._sample_action(q_value, exploration_coefficient)
+
136            for t in range(self.worker_steps):
@@ -460,11 +466,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Run sampled actions on each worker

+

Get Q_values for the current observation

-
129                for w, worker in enumerate(self.workers):
-130                    worker.child.send(("step", actions[w]))
+
138                q_value = self.model(obs_to_torch(self.obs))
@@ -472,10 +477,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Collect information from each worker

+

Sample actions

-
133                for w, worker in enumerate(self.workers):
+
140                actions = self._sample_action(q_value, exploration_coefficient)
@@ -483,10 +488,11 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Get results after executing the actions

+

Run sampled actions on each worker

-
135                    next_obs, reward, done, info = worker.child.recv()
+
143                for w, worker in enumerate(self.workers):
+144                    worker.child.send(("step", actions[w]))
@@ -494,10 +500,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

Add transition to replay buffer

+

Collect information from each worker

-
138                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
+
147                for w, worker in enumerate(self.workers):
@@ -505,15 +511,10 @@ We refer to $\epsilon$ as exploration_coefficient.

-

update episode information: collect episode info, which is available when an episode finishes; this includes the total reward and the length of the episode. Look at Game to see how it works.

+

Get results after executing the actions

-
144                    if info:
-145                        tracker.add('reward', info['reward'])
-146                        tracker.add('length', info['length'])
+
149                    next_obs, reward, done, info = worker.child.recv()
@@ -521,21 +522,26 @@ collect episode info, which is available if an episode finished; -

update current observation

+

Add transition to replay buffer

-
149                    self.obs[w] = next_obs
+
152                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
-
+
-

Train the model

+

update episode information: collect episode info, which is available when an episode finishes; this includes the total reward and the length of the episode. Look at Game to see how it works.

-
151    def train(self, beta: float):
+
158                    if info:
+159                        tracker.add('reward', info['reward'])
+160                        tracker.add('length', info['length'])
@@ -543,21 +549,21 @@ collect episode info, which is available if an episode finished; - +

update current observation

-
155        for _ in range(self.train_epochs):
+
163                    self.obs[w] = next_obs
-
+
-

Sample from priority replay buffer

+

Train the model

-
157            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
+
165    def train(self, beta: float):
@@ -565,10 +571,10 @@ collect episode info, which is available if an episode finished; -

Get the predicted Q-value

+
-
159            q_value = self.model(obs_to_torch(samples['obs']))
+
169        for _ in range(self.train_epochs):
@@ -576,11 +582,10 @@ collect episode info, which is available if an episode finished; -

Get the Q-values of the next state for Double Q-learning. -Gradients shouldn’t propagate for these

+

Sample from priority replay buffer

-
163            with torch.no_grad():
+
171            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
@@ -588,10 +593,10 @@ Gradients shouldn’t propagate for these

-

Get $\color{cyan}Q(s’;\color{cyan}{\theta_i})$

+

Get the predicted Q-value

-
165                double_q_value = self.model(obs_to_torch(samples['next_obs']))
+
173            q_value = self.model(obs_to_torch(samples['obs']))
@@ -599,10 +604,11 @@ Gradients shouldn’t propagate for these

-

Get $\color{orange}Q(s’;\color{orange}{\theta_i^{-}})$

+

Get the Q-values of the next state for Double Q-learning. +Gradients shouldn’t propagate for these

-
167                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
+
177            with torch.no_grad():
@@ -610,15 +616,10 @@ Gradients shouldn’t propagate for these

-

Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.

+

Get $\color{cyan}Q(s’;\color{cyan}{\theta_i})$

-
170            td_errors, loss = self.loss_func(q_value,
-171                                             q_value.new_tensor(samples['action']),
-172                                             double_q_value, target_q_value,
-173                                             q_value.new_tensor(samples['done']),
-174                                             q_value.new_tensor(samples['reward']),
-175                                             q_value.new_tensor(samples['weights']))
+
179                double_q_value = self.model(obs_to_torch(samples['next_obs']))
@@ -626,10 +627,10 @@ Gradients shouldn’t propagate for these

-

Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$

+

Get $\color{orange}Q(s’;\color{orange}{\theta_i^{-}})$

-
178            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
+
181                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
@@ -637,10 +638,15 @@ Gradients shouldn’t propagate for these

-

Update replay buffer priorities

+

Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.

-
180            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
+
184            td_errors, loss = self.loss_func(q_value,
+185                                             q_value.new_tensor(samples['action']),
+186                                             double_q_value, target_q_value,
+187                                             q_value.new_tensor(samples['done']),
+188                                             q_value.new_tensor(samples['reward']),
+189                                             q_value.new_tensor(samples['weights']))
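For reference, these quantities combine into the Double DQN target and the TD error computed inside QFuncLoss: $y_i = r_i + \gamma (1 - d_i) \, Q\big(s'_i, \arg\max_{a'} Q(s'_i, a'; \theta_i); \theta_i^-\big)$ and $\delta_i = Q(s_i, a_i; \theta_i) - y_i$, where $\theta_i$ are the online model parameters and $\theta_i^-$ are the target model parameters.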
@@ -648,10 +654,10 @@ Gradients shouldn’t propagate for these

-

Zero out the previously calculated gradients

+

Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$

-
183            self.optimizer.zero_grad()
+
192            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
@@ -659,10 +665,10 @@ Gradients shouldn’t propagate for these

-

Calculate gradients

+

Update replay buffer priorities

-
185            loss.backward()
+
194            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
@@ -670,10 +676,11 @@ Gradients shouldn’t propagate for these

-

Clip gradients

+

Set learning rate

-
187            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
+
197            for pg in self.optimizer.param_groups:
+198                pg['lr'] = self.learning_rate()
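Writing the learning rate into optimizer.param_groups on every step is standard PyTorch; here the value comes from the FloatDynamicHyperParam, so it can be changed while the run is in progress. A minimal standalone sketch of the mechanism, with a plain float instead of the dynamic hyperparameter:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)


def set_learning_rate(optimizer: torch.optim.Optimizer, lr: float):
    """Overwrite the learning rate of every parameter group in place."""
    for pg in optimizer.param_groups:
        pg['lr'] = lr


set_learning_rate(optimizer, 1e-4)
print(optimizer.param_groups[0]['lr'])  # 0.0001
```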
@@ -681,21 +688,21 @@ Gradients shouldn’t propagate for these

-

Update parameters based on gradients

+

Zero out the previously calculated gradients

-
189            self.optimizer.step()
+
200            self.optimizer.zero_grad()
-
+
-

Run training loop

+

Calculate gradients

-
191    def run_training_loop(self):
+
202            loss.backward()
@@ -703,11 +710,10 @@ Gradients shouldn’t propagate for these

-

Last 100 episode information

+

Clip gradients

-
197        tracker.set_queue('reward', 100, True)
-198        tracker.set_queue('length', 100, True)
+
204            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
@@ -715,24 +721,21 @@ Gradients shouldn’t propagate for these

-

Copy to target network initially

+

Update parameters based on gradients

-
201        self.target_model.load_state_dict(self.model.state_dict())
-202
-203        for update in monit.loop(self.updates):
+
206            self.optimizer.step()
-
+
-

$\epsilon$, exploration fraction

+

Run training loop

-
205            exploration = self.exploration_coefficient(update)
-206            tracker.add('exploration', exploration)
+
208    def run_training_loop(self):
@@ -740,11 +743,11 @@ Gradients shouldn’t propagate for these

-

$\beta$ for prioritized replay

+

Last 100 episode information

-
208            beta = self.prioritized_replay_beta(update)
-209            tracker.add('beta', beta)
+
214        tracker.set_queue('reward', 100, True)
+215        tracker.set_queue('length', 100, True)
@@ -752,10 +755,12 @@ Gradients shouldn’t propagate for these

-

Sample with current policy

+

Copy to target network initially

-
212            self.sample(exploration)
+
218        self.target_model.load_state_dict(self.model.state_dict())
+219
+220        for update in monit.loop(self.updates):
@@ -763,10 +768,11 @@ Gradients shouldn’t propagate for these

-

Start training after the buffer is full

+

$\epsilon$, exploration fraction

-
215            if self.replay_buffer.is_full():
+
222            exploration = self.exploration_coefficient(update)
+223            tracker.add('exploration', exploration)
@@ -774,10 +780,11 @@ Gradients shouldn’t propagate for these

-

Train the model

+

$\beta$ for prioritized replay

-
217                self.train(beta)
+
225            beta = self.prioritized_replay_beta(update)
+226            tracker.add('beta', beta)
@@ -785,11 +792,10 @@ Gradients shouldn’t propagate for these

-

Periodically update target network

+

Sample with current policy

-
220                if update % self.update_target_model == 0:
-221                    self.target_model.load_state_dict(self.model.state_dict())
+
229            self.sample(exploration)
@@ -797,10 +803,10 @@ Gradients shouldn’t propagate for these

-

Save tracked indicators.

+

Start training after the buffer is full

-
224            tracker.save()
+
232            if self.replay_buffer.is_full():
@@ -808,23 +814,22 @@ Gradients shouldn’t propagate for these

-

Add a new line to the screen periodically

+

Train the model

-
226            if (update + 1) % 1_000 == 0:
-227                logger.log()
+
234                self.train(beta)
-
+
-

Destroy

-

Stop the workers

+

Periodically update target network

-
229    def destroy(self):
+
237                if update % self.update_target_model == 0:
+238                    self.target_model.load_state_dict(self.model.state_dict())
@@ -832,11 +837,10 @@ Gradients shouldn’t propagate for these

- +

Save tracked indicators.

-
234        for worker in self.workers:
-235            worker.child.send(("close", None))
+
241            tracker.save()
@@ -844,21 +848,23 @@ Gradients shouldn’t propagate for these

- +

Add a new line to the screen periodically

-
238def main():
+
243            if (update + 1) % 1_000 == 0:
+244                logger.log()
-
+
-

Create the experiment

+

Destroy

+

Stop the workers

-
240    experiment.create(name='dqn')
+
246    def destroy(self):
@@ -866,10 +872,11 @@ Gradients shouldn’t propagate for these

-

Initialize the trainer

+
-
242    m = Trainer()
+
251        for worker in self.workers:
+252            worker.child.send(("close", None))
@@ -877,11 +884,10 @@ Gradients shouldn’t propagate for these

-

Run and monitor the experiment

+
-
244    with experiment.start():
-245        m.run_training_loop()
+
255def main():
@@ -889,10 +895,10 @@ Gradients shouldn’t propagate for these

-

Stop the workers

+

Create the experiment

-
247    m.destroy()
+
257    experiment.create(name='dqn')
@@ -900,11 +906,145 @@ Gradients shouldn’t propagate for these

+

Configurations

+
+
+
260    configs = {
+
+ +
+
+ +

Number of updates

+
+
+
262        'updates': 1_000_000,
+
+
+
+
+ +

Number of epochs to train the model with sampled data.

+
+
+
264        'epochs': 8,
+
+
+
+
+ +

Number of worker processes

+
+
+
266        'n_workers': 8,
+
+
+
+
+ +

Number of steps to run on each process for a single update

+
+
+
268        'worker_steps': 4,
+
+
+
+
+ +

Mini batch size

+
+
+
270        'mini_batch_size': 32,
+
+
+
+
+ +

Target model updating interval

+
+
+
272        'update_target_model': 250,
+
+
+
+
+ +

Learning rate.

+
+
+
274        'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
+275    }
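Based on how it is used in Trainer.train(), FloatDynamicHyperParam appears to take an initial value and an allowed range, and calling the object returns the current value; reading the tuple as (min, max) is an assumption:

```python
# Hypothetical usage, inferred from the calls above
lr = configs['learning_rate']()  # 2.5e-4 until it is changed while the run is in progress
```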
+
+
+
+
+ +

Configurations

+
+
+
278    experiment.configs(configs)
+
+
+
+
+ +

Initialize the trainer

+
+
+
281    m = Trainer(**configs)
+
+
+
+
+ +

Run and monitor the experiment

+
+
+
283    with experiment.start():
+284        m.run_training_loop()
+
+
+
+
+ +

Stop the workers

+
+
+
286    m.destroy()
+
+
+
+
+

Run it

-
251if __name__ == "__main__":
-252    main()
+
290if __name__ == "__main__":
+291    main()
-
27from typing import Tuple
-28
-29import torch
-30from torch import nn
-31
-32from labml import tracker
-33from labml_helpers.module import Module
-34from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
+
25from typing import Tuple
+26
+27import torch
+28from torch import nn
+29
+30from labml import tracker
+31from labml_helpers.module import Module
+32from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
@@ -158,7 +155,7 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

-
37class QFuncLoss(Module):
+
35class QFuncLoss(Module):
@@ -169,10 +166,10 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

-
104    def __init__(self, gamma: float):
-105        super().__init__()
-106        self.gamma = gamma
-107        self.huber_loss = nn.SmoothL1Loss(reduction='none')
+
102    def __init__(self, gamma: float):
+103        super().__init__()
+104        self.gamma = gamma
+105        self.huber_loss = nn.SmoothL1Loss(reduction='none')
@@ -191,9 +188,9 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

-
109    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
-110                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
-111                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
107    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
+108                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
+109                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -204,8 +201,8 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

$Q(s,a;\theta_i)$

-
123        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
-124        tracker.add('q_sampled_action', q_sampled_action)
+
121        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
+122        tracker.add('q_sampled_action', q_sampled_action)
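The gather call picks, for each sampled transition, the Q-value of the action that was actually taken. A tiny standalone example of that indexing:

```python
import torch

q = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
action = torch.tensor([2, 0])  # action taken in each sampled transition

q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
print(q_sampled_action)  # tensor([3., 4.])
```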
@@ -222,7 +219,7 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

-
132        with torch.no_grad():
+
130        with torch.no_grad():
@@ -236,7 +233,7 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

-
136            best_next_action = torch.argmax(double_q, -1)
+
134            best_next_action = torch.argmax(double_q, -1)
@@ -252,7 +249,7 @@ the value is taken from $\color{orange}{\theta_i^{-}}$.

-
142            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
+
140            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
@@ -272,8 +269,8 @@ the next state Q values if the game ended.

-
153            q_update = reward + self.gamma * best_next_q_value * (1 - done)
-154            tracker.add('q_update', q_update)
+
151            q_update = reward + self.gamma * best_next_q_value * (1 - done)
+152            tracker.add('q_update', q_update)
@@ -284,8 +281,8 @@ the next state Q values if the game ended.

Temporal difference error $\delta$ is used to weigh samples in replay buffer

-
157            td_error = q_sampled_action - q_update
-158            tracker.add('td_error', td_error)
+
155            td_error = q_sampled_action - q_update
+156            tracker.add('td_error', td_error)
@@ -297,7 +294,7 @@ the next state Q values if the game ended.

Huber loss is used instead of mean squared error loss because it is less sensitive to outliers

-
162        losses = self.huber_loss(q_sampled_action, q_update)
+
160        losses = self.huber_loss(q_sampled_action, q_update)
@@ -308,10 +305,10 @@ mean squared error loss because it is less sensitive to outliers

Get weighted means

-
164        loss = torch.mean(weights * losses)
-165        tracker.add('loss', loss)
-166
-167        return td_error, loss
+
162        loss = torch.mean(weights * losses)
+163        tracker.add('loss', loss)
+164
+165        return td_error, loss

Deep Q Network (DQN) Model

+

Open In Colab +View Run

-
10import torch
-11from torch import nn
-12
-13from labml_helpers.module import Module
+
13import torch
+14from torch import nn
+15
+16from labml_helpers.module import Module
@@ -109,7 +111,7 @@ and in some states the action is significant. Dueling network allows
We share the initial layers of the $V$ and $A$ networks.

-
16class Model(Module):
+
19class Model(Module):
@@ -120,9 +122,9 @@ We share the initial layers of the $V$ and $A$ networks.

-
47    def __init__(self):
-48        super().__init__()
-49        self.conv = nn.Sequential(
+
50    def __init__(self):
+51        super().__init__()
+52        self.conv = nn.Sequential(
@@ -134,8 +136,8 @@ We share the initial layers of the $V$ and $A$ networks.

$84\times84$ frame and produces a $20\times20$ frame

-
52            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
-53            nn.ReLU(),
+
55            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
+56            nn.ReLU(),
@@ -147,8 +149,8 @@ $84\times84$ frame and produces a $20\times20$ frame

$20\times20$ frame and produces a $9\times9$ frame

-
57            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
-58            nn.ReLU(),
+
60            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
+61            nn.ReLU(),
@@ -160,9 +162,9 @@ $20\times20$ frame and produces a $9\times9$ frame

$9\times9$ frame and produces a $7\times7$ frame

-
62            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
-63            nn.ReLU(),
-64        )
+
65            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
+66            nn.ReLU(),
+67        )
@@ -175,8 +177,8 @@ frame from third convolution layer, and outputs $512$ features

-
69        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
-70        self.activation = nn.ReLU()
+
72        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
+73        self.activation = nn.ReLU()
@@ -187,11 +189,11 @@ $512$ features

This head gives the state value $V$

-
73        self.state_value = nn.Sequential(
-74            nn.Linear(in_features=512, out_features=256),
-75            nn.ReLU(),
-76            nn.Linear(in_features=256, out_features=1),
-77        )
+
76        self.state_value = nn.Sequential(
+77            nn.Linear(in_features=512, out_features=256),
+78            nn.ReLU(),
+79            nn.Linear(in_features=256, out_features=1),
+80        )
@@ -202,11 +204,11 @@ $512$ features

This head gives the action value $A$

-
79        self.action_value = nn.Sequential(
-80            nn.Linear(in_features=512, out_features=256),
-81            nn.ReLU(),
-82            nn.Linear(in_features=256, out_features=4),
-83        )
+
82        self.action_value = nn.Sequential(
+83            nn.Linear(in_features=512, out_features=256),
+84            nn.ReLU(),
+85            nn.Linear(in_features=256, out_features=4),
+86        )
@@ -217,7 +219,7 @@ $512$ features

-
85    def forward(self, obs: torch.Tensor):
+
88    def forward(self, obs: torch.Tensor):
@@ -228,7 +230,7 @@ $512$ features

Convolution

-
87        h = self.conv(obs)
+
90        h = self.conv(obs)
@@ -239,7 +241,7 @@ $512$ features

Reshape for linear layers

-
89        h = h.reshape((-1, 7 * 7 * 64))
+
92        h = h.reshape((-1, 7 * 7 * 64))
@@ -250,7 +252,7 @@ $512$ features

Linear layer

-
92        h = self.activation(self.lin(h))
+
95        h = self.activation(self.lin(h))
@@ -261,7 +263,7 @@ $512$ features

$A$

-
95        action_value = self.action_value(h)
+
98        action_value = self.action_value(h)
@@ -272,7 +274,7 @@ $512$ features

$V$

-
97        state_value = self.state_value(h)
+
100        state_value = self.state_value(h)
@@ -283,7 +285,7 @@ $512$ features

$A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a’ \in \mathcal{A}} A(s, a’)$

-
100        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
+
103        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
@@ -294,9 +296,9 @@ $512$ features

$Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a’ \in \mathcal{A}} A(s, a’)\Big)$

-
102        q = state_value + action_score_centered
-103
-104        return q
+
105        q = state_value + action_score_centered
+106
+107        return q
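A toy illustration of this dueling aggregation, with a batch of two states and three actions (numbers chosen arbitrarily):

```python
import torch

state_value = torch.tensor([[1.0], [2.0]])       # V(s), shape (batch, 1)
action_value = torch.tensor([[0.5, 1.5, 1.0],
                             [0.0, 3.0, 0.0]])   # A(s, a), shape (batch, actions)

action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
q = state_value + action_score_centered
print(q)
# tensor([[0.5000, 1.5000, 1.0000],
#         [1.0000, 4.0000, 1.0000]])
```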
-
13import random
-14
-15import numpy as np
+
16import random
+17
+18import numpy as np
@@ -142,7 +144,7 @@ That is,

We use the same structure to compute the minimum.

-
18class ReplayBuffer:
+
21class ReplayBuffer:
@@ -153,7 +155,7 @@ That is,

Initialize

-
88    def __init__(self, capacity, alpha):
+
91    def __init__(self, capacity, alpha):
@@ -164,7 +166,7 @@ That is,

We use a power of $2$ for capacity because it simplifies the code and debugging

-
93        self.capacity = capacity
+
96        self.capacity = capacity
@@ -175,7 +177,7 @@ That is,

$\alpha$

-
95        self.alpha = alpha
+
98        self.alpha = alpha
@@ -186,8 +188,8 @@ That is,

Maintain segment binary trees to take sum and find minimum over a range

-
98        self.priority_sum = [0 for _ in range(2 * self.capacity)]
-99        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
+
101        self.priority_sum = [0 for _ in range(2 * self.capacity)]
+102        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
@@ -198,7 +200,7 @@ That is,

Current max priority, $p$, to be assigned to new transitions

-
102        self.max_priority = 1.
+
105        self.max_priority = 1.
@@ -209,13 +211,13 @@ That is,

Arrays for buffer

-
105        self.data = {
-106            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
-107            'action': np.zeros(shape=capacity, dtype=np.int32),
-108            'reward': np.zeros(shape=capacity, dtype=np.float32),
-109            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
-110            'done': np.zeros(shape=capacity, dtype=np.bool)
-111        }
+
108        self.data = {
+109            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
+110            'action': np.zeros(shape=capacity, dtype=np.int32),
+111            'reward': np.zeros(shape=capacity, dtype=np.float32),
+112            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
+113            'done': np.zeros(shape=capacity, dtype=np.bool)
+114        }
@@ -227,7 +229,7 @@ That is, slot

-
114        self.next_idx = 0
+
117        self.next_idx = 0
@@ -238,7 +240,7 @@ slot

Size of the buffer

-
117        self.size = 0
+
120        self.size = 0
@@ -249,7 +251,7 @@ slot

Add sample to queue

-
119    def add(self, obs, action, reward, next_obs, done):
+
122    def add(self, obs, action, reward, next_obs, done):
@@ -260,7 +262,7 @@ slot

Get next available slot

-
125        idx = self.next_idx
+
128        idx = self.next_idx
@@ -271,11 +273,11 @@ slot

store in the queue

-
128        self.data['obs'][idx] = obs
-129        self.data['action'][idx] = action
-130        self.data['reward'][idx] = reward
-131        self.data['next_obs'][idx] = next_obs
-132        self.data['done'][idx] = done
+
131        self.data['obs'][idx] = obs
+132        self.data['action'][idx] = action
+133        self.data['reward'][idx] = reward
+134        self.data['next_obs'][idx] = next_obs
+135        self.data['done'][idx] = done
@@ -286,7 +288,7 @@ slot

Increment next available slot

-
135        self.next_idx = (idx + 1) % self.capacity
+
138        self.next_idx = (idx + 1) % self.capacity
@@ -297,7 +299,7 @@ slot

Calculate the size

-
137        self.size = min(self.capacity, self.size + 1)
+
140        self.size = min(self.capacity, self.size + 1)
@@ -308,7 +310,7 @@ slot

$p_i^\alpha$, new samples get max_priority

-
140        priority_alpha = self.max_priority ** self.alpha
+
143        priority_alpha = self.max_priority ** self.alpha
@@ -319,8 +321,8 @@ slot

Update the two segment trees for sum and minimum

-
142        self._set_priority_min(idx, priority_alpha)
-143        self._set_priority_sum(idx, priority_alpha)
+
145        self._set_priority_min(idx, priority_alpha)
+146        self._set_priority_sum(idx, priority_alpha)
@@ -331,7 +333,7 @@ slot

Set priority in binary segment tree for minimum

-
145    def _set_priority_min(self, idx, priority_alpha):
+
148    def _set_priority_min(self, idx, priority_alpha):
@@ -342,8 +344,8 @@ slot

Leaf of the binary tree

-
151        idx += self.capacity
-152        self.priority_min[idx] = priority_alpha
+
154        idx += self.capacity
+155        self.priority_min[idx] = priority_alpha
@@ -355,7 +357,7 @@ slot

Continue until the root of the tree.

-
156        while idx >= 2:
+
159        while idx >= 2:
@@ -366,7 +368,7 @@ Continue until the root of the tree.

Get the index of the parent node

-
158            idx //= 2
+
161            idx //= 2
@@ -377,7 +379,7 @@ Continue until the root of the tree.

Value of the parent node is the minimum of its two children

-
160            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
+
163            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
@@ -388,7 +390,7 @@ Continue until the root of the tree.

Set priority in binary segment tree for sum

-
162    def _set_priority_sum(self, idx, priority):
+
165    def _set_priority_sum(self, idx, priority):
@@ -399,7 +401,7 @@ Continue until the root of the tree.

Leaf of the binary tree

-
168        idx += self.capacity
+
171        idx += self.capacity
@@ -410,7 +412,7 @@ Continue until the root of the tree.

Set the priority at the leaf

-
170        self.priority_sum[idx] = priority
+
173        self.priority_sum[idx] = priority
@@ -422,7 +424,7 @@ Continue until the root of the tree.

Continue until the root of the tree.

-
174        while idx >= 2:
+
177        while idx >= 2:
@@ -433,7 +435,7 @@ Continue until the root of the tree.

Get the index of the parent node

-
176            idx //= 2
+
179            idx //= 2
@@ -444,7 +446,7 @@ Continue until the root of the tree.

Value of the parent node is the sum of its two children

-
178            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
+
181            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
@@ -455,7 +457,7 @@ Continue until the root of the tree.

$\sum_k p_k^\alpha$

-
180    def _sum(self):
+
183    def _sum(self):
@@ -466,7 +468,7 @@ Continue until the root of the tree.

The root node keeps the sum of all values

-
186        return self.priority_sum[1]
+
189        return self.priority_sum[1]
@@ -477,7 +479,7 @@ Continue until the root of the tree.

$\min_k p_k^\alpha$

-
188    def _min(self):
+
191    def _min(self):
@@ -488,7 +490,7 @@ Continue until the root of the tree.

The root node keeps the minimum of all values

-
194        return self.priority_min[1]
+
197        return self.priority_min[1]
@@ -499,7 +501,7 @@ Continue until the root of the tree.

Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$

-
196    def find_prefix_sum_idx(self, prefix_sum):
+
199    def find_prefix_sum_idx(self, prefix_sum):
@@ -510,8 +512,8 @@ Continue until the root of the tree.

Start from the root

-
202        idx = 1
-203        while idx < self.capacity:
+
205        idx = 1
+206        while idx < self.capacity:
@@ -522,7 +524,7 @@ Continue until the root of the tree.

If the sum of the left branch is higher than the required sum

-
205            if self.priority_sum[idx * 2] > prefix_sum:
+
208            if self.priority_sum[idx * 2] > prefix_sum:
@@ -533,8 +535,8 @@ Continue until the root of the tree.

Go to left branch of the tree

-
207                idx = 2 * idx
-208            else:
+
210                idx = 2 * idx
+211            else:
@@ -546,8 +548,8 @@ Continue until the root of the tree.

branch from the required sum

-
211                prefix_sum -= self.priority_sum[idx * 2]
-212                idx = 2 * idx + 1
+
214                prefix_sum -= self.priority_sum[idx * 2]
+215                idx = 2 * idx + 1
@@ -559,7 +561,7 @@ Continue until the root of the tree.

to get the index of actual value

-
216        return idx - self.capacity
+
219        return idx - self.capacity
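A worked example of this search, with capacity 4 and leaf priorities [3, 1, 4, 2], using an independent sketch of the same 1-indexed sum-tree layout:

```python
capacity = 4
priorities = [3.0, 1.0, 4.0, 2.0]

# Leaves live at indexes [capacity, 2 * capacity); each parent stores the sum of its children
tree = [0.0] * (2 * capacity)
for i, p in enumerate(priorities):
    tree[capacity + i] = p
for node in range(capacity - 1, 0, -1):
    tree[node] = tree[2 * node] + tree[2 * node + 1]


def find_prefix_sum_idx(prefix_sum: float) -> int:
    idx = 1  # start from the root
    while idx < capacity:
        if tree[2 * idx] > prefix_sum:
            idx = 2 * idx                 # the prefix falls inside the left subtree
        else:
            prefix_sum -= tree[2 * idx]   # skip the left subtree and descend right
            idx = 2 * idx + 1
    return idx - capacity


print(find_prefix_sum_idx(3.5))  # 1: cumulative sums are [3, 4, 8, 10] and 3 <= 3.5 < 4
print(find_prefix_sum_idx(7.9))  # 2: 4 <= 7.9 < 8
```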
@@ -570,7 +572,7 @@ to get the index of actual value

Sample from buffer

-
218    def sample(self, batch_size, beta):
+
221    def sample(self, batch_size, beta):
@@ -581,10 +583,10 @@ to get the index of actual value

Initialize samples

-
224        samples = {
-225            'weights': np.zeros(shape=batch_size, dtype=np.float32),
-226            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
-227        }
+
227        samples = {
+228            'weights': np.zeros(shape=batch_size, dtype=np.float32),
+229            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
+230        }
@@ -595,10 +597,10 @@ to get the index of actual value

Get sample indexes

-
230        for i in range(batch_size):
-231            p = random.random() * self._sum()
-232            idx = self.find_prefix_sum_idx(p)
-233            samples['indexes'][i] = idx
+
233        for i in range(batch_size):
+234            p = random.random() * self._sum()
+235            idx = self.find_prefix_sum_idx(p)
+236            samples['indexes'][i] = idx
@@ -609,7 +611,7 @@ to get the index of actual value

$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$

-
236        prob_min = self._min() / self._sum()
+
239        prob_min = self._min() / self._sum()
@@ -620,10 +622,10 @@ to get the index of actual value

$\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$

-
238        max_weight = (prob_min * self.size) ** (-beta)
-239
-240        for i in range(batch_size):
-241            idx = samples['indexes'][i]
+
241        max_weight = (prob_min * self.size) ** (-beta)
+242
+243        for i in range(batch_size):
+244            idx = samples['indexes'][i]
@@ -634,7 +636,7 @@ to get the index of actual value

$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$

-
243            prob = self.priority_sum[idx + self.capacity] / self._sum()
+
246            prob = self.priority_sum[idx + self.capacity] / self._sum()
@@ -645,7 +647,7 @@ to get the index of actual value

$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$

-
245            weight = (prob * self.size) ** (-beta)
+
248            weight = (prob * self.size) ** (-beta)
@@ -657,7 +659,7 @@ to get the index of actual value

which also cancels off the $\frac{1}{N}$ term

-
248            samples['weights'][i] = weight / max_weight
+
251            samples['weights'][i] = weight / max_weight
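To make these weights concrete, take the toy priorities [3, 1, 4, 2] from the sketch above, so $\sum_k p_k^\alpha = 10$ and $N = 4$, and let $\beta = 0.5$: the largest-priority sample has $P(i) = 0.4$ and $w_i = (4 \cdot 0.4)^{-0.5} \approx 0.79$, while the smallest has $P(i) = 0.1$ and $w_i = \max_i w_i = (4 \cdot 0.1)^{-0.5} \approx 1.58$; dividing by $\max_i w_i$ gives normalized weights of roughly $0.5$ and $1$.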
@@ -668,10 +670,10 @@ to get the index of actual value

Get samples data

-
251        for k, v in self.data.items():
-252            samples[k] = v[samples['indexes']]
-253
-254        return samples
+
254        for k, v in self.data.items():
+255            samples[k] = v[samples['indexes']]
+256
+257        return samples
@@ -682,7 +684,7 @@ to get the index of actual value

Update priorities

-
256    def update_priorities(self, indexes, priorities):
+
259    def update_priorities(self, indexes, priorities):
@@ -693,7 +695,7 @@ to get the index of actual value

-
261        for idx, priority in zip(indexes, priorities):
+
264        for idx, priority in zip(indexes, priorities):
@@ -704,7 +706,7 @@ to get the index of actual value

Set current max priority

-
263            self.max_priority = max(self.max_priority, priority)
+
266            self.max_priority = max(self.max_priority, priority)
@@ -715,7 +717,7 @@ to get the index of actual value

Calculate $p_i^\alpha$

-
266            priority_alpha = priority ** self.alpha
+
269            priority_alpha = priority ** self.alpha
@@ -726,8 +728,8 @@ to get the index of actual value

Update the trees

-
268            self._set_priority_min(idx, priority_alpha)
-269            self._set_priority_sum(idx, priority_alpha)
+
271            self._set_priority_min(idx, priority_alpha)
+272            self._set_priority_sum(idx, priority_alpha)
@@ -738,7 +740,7 @@ to get the index of actual value

Whether the buffer is full

-
271    def is_full(self):
+
274    def is_full(self):
@@ -749,7 +751,7 @@ to get the index of actual value

-
275        return self.capacity == self.size
+
278        return self.capacity == self.size
-
391    m = Trainer(
-392        updates=configs['updates'],
-393        epochs=configs['epochs'],
-394        n_workers=configs['n_workers'],
-395        worker_steps=configs['worker_steps'],
-396        batches=configs['batches'],
-397        value_loss_coef=configs['value_loss_coef'],
-398        entropy_bonus_coef=configs['entropy_bonus_coef'],
-399        clip_range=configs['clip_range'],
-400        learning_rate=configs['learning_rate'],
-401    )
+
391    m = Trainer(**configs)
@@ -1232,8 +1222,8 @@ You can change this while the experiment is running.

Run and monitor the experiment

-
404    with experiment.start():
-405        m.run_training_loop()
+
394    with experiment.start():
+395        m.run_training_loop()
@@ -1244,7 +1234,7 @@ You can change this while the experiment is running.

Stop the workers

-
407    m.destroy()
+
397    m.destroy()
@@ -1255,8 +1245,8 @@ You can change this while the experiment is running.

Run it

-
411if __name__ == "__main__":
-412    main()
+
401if __name__ == "__main__":
+402    main()