From 2ddfe1b252d03dc798e9f0b31c146344dfbd3855 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Sat, 24 Oct 2020 18:53:02 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9A=20annotations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- labml_nn/rl/dqn/__init__.py | 138 +++++++++++++++++++++++--- labml_nn/rl/dqn/experiment.py | 165 +------------------------------ labml_nn/rl/dqn/model.py | 100 +++++++++++++++++++ labml_nn/rl/dqn/replay_buffer.py | 28 ++++-- 4 files changed, 244 insertions(+), 187 deletions(-) create mode 100644 labml_nn/rl/dqn/model.py diff --git a/labml_nn/rl/dqn/__init__.py b/labml_nn/rl/dqn/__init__.py index b23f9d54..d4d1c522 100644 --- a/labml_nn/rl/dqn/__init__.py +++ b/labml_nn/rl/dqn/__init__.py @@ -1,13 +1,24 @@ """ -This is a Deep Q Learning implementation with: +# Deep Q Networks + +This is a Deep Q Learning implementation that uses: + +* [Dueling Network](model.html) +* [Prioritized Replay](replay_buffer.html) * Double Q Network -* Dueling Network -* Prioritized Replay + +Here's the [experiment](experiment.html) and [model](model.html). + +\( + \def\green#1{{\color{yellowgreen}{#1}}} +\) + """ from typing import Tuple import torch +from torch import nn from labml import tracker from labml_helpers.module import Module @@ -15,36 +26,133 @@ from labml_nn.rl.dqn.replay_buffer import ReplayBuffer class QFuncLoss(Module): + """ + ## Train the model + + We want to find the optimal action-value function. + + \begin{align} + Q^*(s,a) &= \max_\pi \mathbb{E} \Big[ + r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi + \Big] + \\ + Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[ + r + \gamma \max_{a'} Q^* (s', a') | s, a + \Big] + \end{align} + + ### Target network 🎯 + In order to improve stability we use experience replay that randomly samples + from previous experience $U(D)$. We also use a Q network + with a separate set of parameters $\color{orange}{\theta_i^{-}}$ to calculate the target. + $\color{orange}{\theta_i^{-}}$ is updated periodically. + This is according to the paper + [Human Level Control Through Deep Reinforcement Learning](https://deepmind.com/research/dqn/). + + So the loss function is, + $$ + \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)} + \bigg[ + \Big( + r + \gamma \max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) - Q(s,a;\theta_i) + \Big) ^ 2 + \bigg] + $$ + + ### Double $Q$-Learning + The max operator in the above calculation uses the same network for both + selecting the best action and for evaluating the value. + That is, + $$ + \max_{a'} Q(s', a'; \theta) = \color{cyan}{Q} + \Big( + s', \mathop{\operatorname{argmax}}_{a'} + \color{cyan}{Q}(s', a'; \color{cyan}{\theta}); \color{cyan}{\theta} + \Big) + $$ + We use [double Q-learning](https://arxiv.org/abs/1509.06461), where + the $\operatorname{argmax}$ is taken from $\color{cyan}{\theta_i}$ and + the value is taken from $\color{orange}{\theta_i^{-}}$.
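+ As a quick illustration with made-up numbers: if at $s'$ the current network gives $\color{cyan}{Q}(s', \cdot; \color{cyan}{\theta_i}) = (3.1, 4.2)$ over two actions while the target network gives $\color{orange}{Q}(s', \cdot; \color{orange}{\theta_i^{-}}) = (3.0, 3.5)$, the $\operatorname{argmax}$ picks the second action, but it is evaluated at $3.5$ rather than $4.2$; decoupling action selection from evaluation like this is what reduces the over-estimation of $Q$ values.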
+ + And the loss function becomes, + \begin{align} + \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)} + \Bigg[ + \bigg( + &r + \gamma \color{orange}{Q} + \Big( + s', + \mathop{\operatorname{argmax}}_{a'} + \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}} + \Big) + \\ + - &Q(s,a;\theta_i) + \bigg) ^ 2 + \Bigg] + \end{align} + """ + def __init__(self, gamma: float): super().__init__() self.gamma = gamma + self.huber_loss = nn.SmoothL1Loss(reduction='none') - def __call__(self, q: torch.Tensor, - action: torch.Tensor, - double_q: torch.Tensor, - target_q: torch.Tensor, - done: torch.Tensor, - reward: torch.Tensor, + def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor, + target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + * `q` - $Q(s;\theta_i)$ + * `action` - $a$ + * `double_q` - $\color{cyan}Q(s';\color{cyan}{\theta_i})$ + * `target_q` - $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$ + * `done` - whether the game ended after taking the action + * `reward` - $r$ + * `weights` - weights of the samples from prioritized experience replay + """ + + # $Q(s,a;\theta_i)$ q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1) tracker.add('q_sampled_action', q_sampled_action) + # Gradients shouldn't propagate through the target value calculation + # $$r + \gamma \color{orange}{Q} + # \Big(s', + # \mathop{\operatorname{argmax}}_{a'} + # \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}} + # \Big)$$ with torch.no_grad(): + # Get the best action at state $s'$ + # $$\mathop{\operatorname{argmax}}_{a'} + # \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i})$$ best_next_action = torch.argmax(double_q, -1) + # Get the q value from the target network for the best action at state $s'$ + # $$\color{orange}{Q} + # \Big(s',\mathop{\operatorname{argmax}}_{a'} + # \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}} + # \Big)$$ best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1) - best_next_q_value *= (1 - done) - - q_update = reward + self.gamma * best_next_q_value + # Calculate the desired Q value. + # We multiply by `(1 - done)` to zero out + # the next state Q values if the game ended.
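+ # (`done` is 1 for a terminal transition, so in that case the target reduces to just the reward `r`.)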
+ # + # $$r + \gamma \color{orange}{Q} + # \Big(s', + # \mathop{\operatorname{argmax}}_{a'} + # \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}} + # \Big)$$ + q_update = reward + self.gamma * best_next_q_value * (1 - done) tracker.add('q_update', q_update) + # Temporal difference error $\delta$ is used to weigh samples in replay buffer td_error = q_sampled_action - q_update tracker.add('td_error', td_error) - # Huber loss - losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none') + # We take [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) instead of + # mean squared error loss because it is less sensitive to outliers + losses = self.huber_loss(q_sampled_action, q_update) + # Get weighted means loss = torch.mean(weights * losses) tracker.add('loss', loss) return td_error, loss - diff --git a/labml_nn/rl/dqn/experiment.py b/labml_nn/rl/dqn/experiment.py index 817f7cf0..bc7078fb 100644 --- a/labml_nn/rl/dqn/experiment.py +++ b/labml_nn/rl/dqn/experiment.py @@ -8,12 +8,11 @@ import numpy as np import torch -from torch import nn from labml import tracker, experiment, logger, monit -from labml_helpers.module import Module from labml_helpers.schedule import Piecewise from labml_nn.rl.dqn import QFuncLoss +from labml_nn.rl.dqn.model import Model from labml_nn.rl.dqn.replay_buffer import ReplayBuffer from labml_nn.rl.game import Worker @@ -23,105 +22,12 @@ else: device = torch.device("cpu") -class Model(Module): - """ - ## Neural Network Model for $Q$ Values - - #### Dueling Network ⚔️ - We are using a [dueling network](https://arxiv.org/abs/1511.06581) - to calculate Q-values. - Intuition behind dueling network architure is that in most states - the action doesn't matter, - and in some states the action is significant. Dueling network allows - this to be represented very well. - - \begin{align} - Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a) - \\ - \mathop{\mathbb{E}}_{a \sim \pi(s)} - \Big[ - A^\pi(s, a) - \Big] - &= 0 - \end{align} - - So we create two networks for $V$ and $A$ and get $Q$ from them. - $$ - Q(s, a) = V(s) + - \Big( - A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a') - \Big) - $$ - We share the initial layers of the $V$ and $A$ networks. - """ - - def __init__(self): - """ - ### Initialize - - We need `scope` because we need multiple copies of variables - for target network and training network. 
- """ - - super().__init__() - self.conv = nn.Sequential( - # The first convolution layer takes a - # 84x84 frame and produces a 20x20 frame - nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4), - nn.ReLU(), - - # The second convolution layer takes a - # 20x20 frame and produces a 9x9 frame - nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), - nn.ReLU(), - - # The third convolution layer takes a - # 9x9 frame and produces a 7x7 frame - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), - nn.ReLU(), - ) - - # A fully connected layer takes the flattened - # frame from third convolution layer, and outputs - # 512 features - self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512) - - self.state_score = nn.Sequential( - nn.Linear(in_features=512, out_features=256), - nn.ReLU(), - nn.Linear(in_features=256, out_features=1), - ) - self.action_score = nn.Sequential( - nn.Linear(in_features=512, out_features=256), - nn.ReLU(), - nn.Linear(in_features=256, out_features=4), - ) - - # - self.activation = nn.ReLU() - - def __call__(self, obs: torch.Tensor): - h = self.conv(obs) - h = h.reshape((-1, 7 * 7 * 64)) - - h = self.activation(self.lin(h)) - - action_score = self.action_score(h) - state_score = self.state_score(h) - - # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$ - action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True) - q = state_score + action_score_centered - - return q - - def obs_to_torch(obs: np.ndarray) -> torch.Tensor: """Scale observations from `[0, 255]` to `[0, 1]`""" return torch.tensor(obs, dtype=torch.float32, device=device) / 255. -class Main(object): +class Trainer: """ ## Main class This class runs the training loop. @@ -239,71 +145,6 @@ class Main(object): self.obs[w] = next_obs def train(self, beta: float): - """ - ### Train the model - - We want to find optimal action-value function. - - \begin{align} - Q^*(s,a) &= \max_\pi \mathbb{E} \Big[ - r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi - \Big] - \\ - Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[ - r + \gamma \max_{a'} Q^* (s', a') | s, a - \Big] - \end{align} - - #### Target network 🎯 - In order to improve stability we use experience replay that randomly sample - from previous experience $U(D)$. We also use a Q network - with a separate set of paramters $\hl1{\theta_i^{-}}$ to calculate the target. - $\hl1{\theta_i^{-}}$ is updated periodically. - This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/). - - So the loss function is, - $$ - \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)} - \bigg[ - \Big( - r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i) - \Big) ^ 2 - \bigg] - $$ - - #### Double $Q$-Learning - The max operator in the above calculation uses same network for both - selecting the best action and for evaluating the value. - That is, - $$ - \max_{a'} Q(s', a'; \theta) = \blue{Q} - \Big( - s', \mathop{\operatorname{argmax}}_{a'} - \blue{Q}(s', a'; \blue{\theta}); \blue{\theta} - \Big) - $$ - We use [double Q-learning](https://arxiv.org/abs/1509.06461), where - the $\operatorname{argmax}$ is taken from $\theta_i$ and - the value is taken from $\theta_i^{-}$. 
- - And the loss function becomes, - \begin{align} - \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)} - \Bigg[ - \bigg( - &r + \gamma \blue{Q} - \Big( - s', - \mathop{\operatorname{argmax}}_{a'} - \green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}} - \Big) - \\ - - &Q(s,a;\theta_i) - \bigg) ^ 2 - \Bigg] - \end{align} - """ - for _ in range(self.train_epochs): # sample from priority replay buffer samples = self.replay_buffer.sample(self.mini_batch_size, beta) @@ -379,7 +220,7 @@ class Main(object): # ## Run it if __name__ == "__main__": experiment.create(name='dqn') - m = Main() + m = Trainer() with experiment.start(): m.run_training_loop() m.destroy() diff --git a/labml_nn/rl/dqn/model.py b/labml_nn/rl/dqn/model.py new file mode 100644 index 00000000..0e40e349 --- /dev/null +++ b/labml_nn/rl/dqn/model.py @@ -0,0 +1,100 @@ +""" +# Neural Network Model +""" + +import torch +from torch import nn + +from labml_helpers.module import Module + + +class Model(Module): + """ + ## Dueling Network ⚔️ Model for $Q$ Values + + We are using a [dueling network](https://arxiv.org/abs/1511.06581) + to calculate Q-values. + Intuition behind dueling network architecture is that in most states + the action doesn't matter, + and in some states the action is significant. Dueling network allows + this to be represented very well. + + \begin{align} + Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a) + \\ + \mathop{\mathbb{E}}_{a \sim \pi(s)} + \Big[ + A^\pi(s, a) + \Big] + &= 0 + \end{align} + + So we create two networks for $V$ and $A$ and get $Q$ from them. + $$ + Q(s, a) = V(s) + + \Big( + A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a') + \Big) + $$ + We share the initial layers of the $V$ and $A$ networks. + """ + + def __init__(self): + """ + ### Initialize + + We keep multiple copies of this model, + one for the target network and one for the training network. """ + + super().__init__() + self.conv = nn.Sequential( + # The first convolution layer takes a + # 84x84 frame and produces a 20x20 frame + nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4), + nn.ReLU(), + + # The second convolution layer takes a + # 20x20 frame and produces a 9x9 frame + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), + nn.ReLU(), + + # The third convolution layer takes a + # 9x9 frame and produces a 7x7 frame + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), + nn.ReLU(), + ) + + # A fully connected layer takes the flattened + # frame from third convolution layer, and outputs + # 512 features + self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512) + + self.state_score = nn.Sequential( + nn.Linear(in_features=512, out_features=256), + nn.ReLU(), + nn.Linear(in_features=256, out_features=1), + ) + self.action_score = nn.Sequential( + nn.Linear(in_features=512, out_features=256), + nn.ReLU(), + nn.Linear(in_features=256, out_features=4), + ) + + # + self.activation = nn.ReLU() + + def __call__(self, obs: torch.Tensor): + h = self.conv(obs) + h = h.reshape((-1, 7 * 7 * 64)) + + h = self.activation(self.lin(h)) + + action_score = self.action_score(h) + state_score = self.state_score(h) + + # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$ + action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True) + q = state_score + action_score_centered + + return q diff --git a/labml_nn/rl/dqn/replay_buffer.py b/labml_nn/rl/dqn/replay_buffer.py index 48d79f44..86dded7d 100644 --- a/labml_nn/rl/dqn/replay_buffer.py +++ b/labml_nn/rl/dqn/replay_buffer.py @@ -1,3 +1,10 @@ +""" +# Prioritized Experience Replay Buffer + +This implements the paper [Prioritized experience replay](https://arxiv.org/abs/1511.05952), +using a binary segment tree. +""" + import numpy as np import random @@ -5,9 +12,10 @@ import random class ReplayBuffer: """ ## Buffer for Prioritized Experience Replay + [Prioritized experience replay](https://arxiv.org/abs/1511.05952) samples important transitions more frequently. - The transitions are prioritized by the Temporal Difference error. + The transitions are prioritized by the Temporal Difference error (td error). We sample transition $i$ with probability, $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$ @@ -21,16 +29,16 @@ class ReplayBuffer: importance-sampling (IS) weights $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$ that fully compensates for the non-uniform probabilities when $\beta = 1$. - We normalize weights by $1/\max_i w_i$ for stability. + We normalize weights by $\frac{1}{\max_i w_i}$ for stability. Unbiased updates matter most near convergence, at the end of training. Therefore we increase $\beta$ towards the end of training. - ### Binary Segment Trees - We use binary segment trees to efficiently calculate + ### Binary Segment Tree + We use a binary segment tree to efficiently calculate $\sum_k^i p_k^\alpha$, the cumulative probability, which is needed to sample. We also use a binary segment tree to find $\min p_i^\alpha$, - which is needed for $1/\max_i w_i$. + which is needed for $\frac{1}{\max_i w_i}$. We can also use a min-heap for this. This is how a binary segment tree works for sum; @@ -54,14 +62,16 @@ class ReplayBuffer: $$N_i = \left\lceil{\frac{N}{2^{D - i}}} \right\rceil$$ This is equal to the sum of nodes in all rows above $i$.
So we can use a single array $a$ to store the tree, where, - $$b_{i,j} = a_{N_1 + j}$$ + $$b_{i,j} \rightarrow a_{N_i + j}$$ Then child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$. That is, $$a_i = a_{2i} + a_{2i + 1}$$ This way of maintaining binary trees is very easy to program. - *Note that we are indexing from 1*. + *Note that we are indexing starting from 1*. + + We use the same structure to compute the minimum. """ def __init__(self, capacity, alpha): @@ -206,7 +216,7 @@ class ReplayBuffer: # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$ weight = (prob * self.size) ** (-beta) # normalize by $\frac{1}{\max_i w_i}$, - # which also cancels off the $\frac{1}/{N}$ term + # which also cancels off the $\frac{1}{N}$ term samples['weights'][i] = weight / max_weight # get samples data @@ -230,7 +240,5 @@ class ReplayBuffer: def is_full(self): """ ### Is the buffer full - - We only start sampling afte the buffer is full. """ return self.capacity == self.size
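For reference, here is a minimal usage sketch (not part of the patch) of how the pieces above fit together on dummy data. The batch size, discount factor, random observations and uniform weights are illustrative assumptions; in the actual experiment these come from the game workers and the prioritized replay buffer.

import torch

from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model

# Online network Q(.; θ_i) and target network Q(.; θ_i⁻);
# the target network is a periodically refreshed copy of the online network.
model = Model()
target_model = Model()
target_model.load_state_dict(model.state_dict())

loss_func = QFuncLoss(gamma=0.99)

# A dummy mini-batch of 8 transitions: each observation is four stacked
# 84x84 frames, already scaled to [0, 1].
obs = torch.rand(8, 4, 84, 84)
next_obs = torch.rand(8, 4, 84, 84)
action = torch.randint(0, 4, (8,))
reward = torch.rand(8)
done = torch.zeros(8)     # 1 where the episode terminated
weights = torch.ones(8)   # importance-sampling weights from prioritized replay

q = model(obs)                         # Q(s, .; θ_i)
with torch.no_grad():
    double_q = model(next_obs)         # Q(s', .; θ_i), used only to pick argmax a'
    target_q = target_model(next_obs)  # Q(s', .; θ_i⁻), used to evaluate that a'

td_error, loss = loss_func(q, action, double_q, target_q, done, reward, weights)
loss.backward()  # gradients flow into θ_i only; |td_error| would set the new priorities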