📚 annotations

This commit is contained in:
Varuna Jayasiri
2020-10-24 18:53:02 +05:30
parent 40be459a92
commit 2ddfe1b252
4 changed files with 244 additions and 187 deletions


@@ -1,13 +1,24 @@
"""
This is a Deep Q Learning implementation with:
# Deep Q Networks
This is a Deep Q Learning implementation that uses:
* [Dueling Network](model.html)
* [Prioritized Replay](replay_buffer.html)
* Double Q Network
* Dueling Network
* Prioritized Replay
Here's the [experiment](experiment.html) and [model](model.html).
\(
\def\green#1{{\color{yellowgreen}{#1}}}
\)
"""
from typing import Tuple
import torch
from torch import nn
from labml import tracker
from labml_helpers.module import Module
@@ -15,36 +26,133 @@ from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
class QFuncLoss(Module):
"""
## Train the model
We want to find the optimal action-value function.
\begin{align}
Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
\Big]
\\
Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
r + \gamma \max_{a'} Q^* (s', a') | s, a
\Big]
\end{align}
### Target network 🎯
In order to improve stability we use experience replay, which randomly samples
from previous experience $U(D)$. We also use a Q network
with a separate set of parameters $\color{orange}{\theta_i^{-}}$ to calculate the target.
$\color{orange}{\theta_i^{-}}$ is updated periodically.
This is according to the paper
[Human Level Control Through Deep Reinforcement Learning](https://deepmind.com/research/dqn/).
So the loss function is,
$$
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\bigg[
\Big(
r + \gamma \max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) - Q(s,a;\theta_i)
\Big) ^ 2
\bigg]
$$
### Double $Q$-Learning
The max operator in the above calculation uses the same network both for
selecting the best action and for evaluating its value.
That is,
$$
\max_{a'} Q(s', a'; \theta) = \color{cyan}{Q}
\Big(
s', \mathop{\operatorname{argmax}}_{a'}
\color{cyan}{Q}(s', a'; \color{cyan}{\theta}); \color{cyan}{\theta}
\Big)
$$
We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
the $\operatorname{argmax}$ is taken from $\color{cyan}{\theta_i}$ and
the value is taken from $\color{orange}{\theta_i^{-}}$.
And the loss function becomes,
\begin{align}
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\Bigg[
\bigg(
&r + \gamma \color{orange}{Q}
\Big(
s',
\mathop{\operatorname{argmax}}_{a'}
\color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
\Big)
\\
- &Q(s,a;\theta_i)
\bigg) ^ 2
\Bigg]
\end{align}
"""
def __init__(self, gamma: float):
super().__init__()
self.gamma = gamma
self.huber_loss = nn.SmoothL1Loss(reduction='none')
def __call__(self, q: torch.Tensor,
action: torch.Tensor,
double_q: torch.Tensor,
target_q: torch.Tensor,
done: torch.Tensor,
reward: torch.Tensor,
def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
* `q` - $Q(s;\theta_i)$
* `action` - $a$
* `double_q` - $\color{cyan}Q(s';\color{cyan}{\theta_i})$
* `target_q` - $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
* `done` - whether the game ended after taking the action
* `reward` - $r$
* `weights` - weights of the samples from prioritized experience replay
"""
# $Q(s,a;\theta_i)$
q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
tracker.add('q_sampled_action', q_sampled_action)
# Gradients shouldn't propagate through the target computation
# $$r + \gamma \color{orange}{Q}
# \Big(s',
# \mathop{\operatorname{argmax}}_{a'}
# \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
# \Big)$$
with torch.no_grad():
# Get the best action at state $s'$
# $$\mathop{\operatorname{argmax}}_{a'}
# \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i})$$
best_next_action = torch.argmax(double_q, -1)
# Get the q value from the target network for the best action at state $s'$
# $$\color{orange}{Q}
# \Big(s',\mathop{\operatorname{argmax}}_{a'}
# \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
# \Big)$$
best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
best_next_q_value *= (1 - done)
q_update = reward + self.gamma * best_next_q_value
# Calculate the desired Q value.
# We multiply by `(1 - done)` to zero out
# the next state Q values if the game ended.
#
# $$r + \gamma \color{orange}{Q}
# \Big(s',
# \mathop{\operatorname{argmax}}_{a'}
# \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
# \Big)$$
q_update = reward + self.gamma * best_next_q_value * (1 - done)
tracker.add('q_update', q_update)
# The temporal difference error $\delta$ is used to weight samples in the replay buffer
td_error = q_sampled_action - q_update
tracker.add('td_error', td_error)
# Huber loss
losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
# We take [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) instead of
# mean squared error loss because it is less sensitive to outliers
losses = self.huber_loss(q_sampled_action, q_update)
# Compute the weighted mean of the losses
loss = torch.mean(weights * losses)
tracker.add('loss', loss)
return td_error, loss
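For reference, here is a minimal standalone sketch of the same double Q-learning target and weighted Huber loss computed with plain PyTorch on random tensors; the batch size, number of actions and `gamma` below are illustrative assumptions, not values from the experiment configuration.

import torch
import torch.nn.functional as F

# Illustrative shapes: a batch of 8 transitions with 4 possible actions
batch, n_actions, gamma = 8, 4, 0.99
q = torch.randn(batch, n_actions, requires_grad=True)  # Q(s; theta_i), gradients flow only through this
double_q = torch.randn(batch, n_actions)                # Q(s'; theta_i)
target_q = torch.randn(batch, n_actions)                # Q(s'; theta_i^-)
action = torch.randint(0, n_actions, (batch,))
reward = torch.randn(batch)
done = torch.zeros(batch)
weights = torch.ones(batch)

# Q(s, a; theta_i) for the actions that were actually taken
q_sampled_action = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
with torch.no_grad():
    # Select the best next action with the online network ...
    best_next_action = double_q.argmax(-1)
    # ... and evaluate it with the target network
    best_next_q = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
    q_update = reward + gamma * best_next_q * (1 - done)

td_error = q_sampled_action - q_update
losses = F.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
loss = torch.mean(weights * losses)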


@@ -8,12 +8,11 @@
import numpy as np
import torch
from torch import nn
from labml import tracker, experiment, logger, monit
from labml_helpers.module import Module
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
@@ -23,105 +22,12 @@ else:
device = torch.device("cpu")
class Model(Module):
"""
## <a name="model"></a>Neural Network Model for $Q$ Values
#### Dueling Network ⚔️
We are using a [dueling network](https://arxiv.org/abs/1511.06581)
to calculate Q-values.
The intuition behind the dueling network architecture is that in most states
the action doesn't matter,
while in some states the action is significant. The dueling network allows
this to be represented very well.
\begin{align}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
\\
\mathop{\mathbb{E}}_{a \sim \pi(s)}
\Big[
A^\pi(s, a)
\Big]
&= 0
\end{align}
So we create two networks for $V$ and $A$ and get $Q$ from them.
$$
Q(s, a) = V(s) +
\Big(
A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
\Big)
$$
We share the initial layers of the $V$ and $A$ networks.
"""
def __init__(self):
"""
### Initialize
We keep multiple copies of the model parameters:
one set for the training network and another for the target network.
"""
super().__init__()
self.conv = nn.Sequential(
# The first convolution layer takes an
# 84x84 frame and produces a 20x20 frame
nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
# The second convolution layer takes a
# 20x20 frame and produces a 9x9 frame
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
# The third convolution layer takes a
# 9x9 frame and produces a 7x7 frame
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
)
# A fully connected layer takes the flattened
# frame from the third convolution layer, and outputs
# 512 features
self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
self.state_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=1),
)
self.action_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=4),
)
#
self.activation = nn.ReLU()
def __call__(self, obs: torch.Tensor):
h = self.conv(obs)
h = h.reshape((-1, 7 * 7 * 64))
h = self.activation(self.lin(h))
action_score = self.action_score(h)
state_score = self.state_score(h)
# $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
q = state_score + action_score_centered
return q
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
"""Scale observations from `[0, 255]` to `[0, 1]`"""
return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
class Main(object):
class Trainer:
"""
## <a name="main"></a>Main class
This class runs the training loop.
@@ -239,71 +145,6 @@ class Main(object):
self.obs[w] = next_obs
def train(self, beta: float):
"""
### Train the model
We want to find the optimal action-value function.
\begin{align}
Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
\Big]
\\
Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
r + \gamma \max_{a'} Q^* (s', a') | s, a
\Big]
\end{align}
#### Target network 🎯
In order to improve stability we use experience replay, which randomly samples
from previous experience $U(D)$. We also use a Q network
with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
$\hl1{\theta_i^{-}}$ is updated periodically.
This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).
So the loss function is,
$$
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\bigg[
\Big(
r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
\Big) ^ 2
\bigg]
$$
#### Double $Q$-Learning
The max operator in the above calculation uses the same network both for
selecting the best action and for evaluating its value.
That is,
$$
\max_{a'} Q(s', a'; \theta) = \blue{Q}
\Big(
s', \mathop{\operatorname{argmax}}_{a'}
\blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
\Big)
$$
We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
the $\operatorname{argmax}$ is taken from $\theta_i$ and
the value is taken from $\theta_i^{-}$.
And the loss function becomes,
\begin{align}
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\Bigg[
\bigg(
&r + \gamma \blue{Q}
\Big(
s',
\mathop{\operatorname{argmax}}_{a'}
\green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
\Big)
\\
- &Q(s,a;\theta_i)
\bigg) ^ 2
\Bigg]
\end{align}
"""
for _ in range(self.train_epochs):
# Sample from the prioritized replay buffer
samples = self.replay_buffer.sample(self.mini_batch_size, beta)
@@ -379,7 +220,7 @@ class Main(object):
# ## Run it
if __name__ == "__main__":
experiment.create(name='dqn')
m = Main()
m = Trainer()
with experiment.start():
m.run_training_loop()
m.destroy()
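The periodic update of the target parameters $\theta_i^{-}$ discussed above is not shown in these hunks; here is a minimal sketch of how such a sync is typically done in PyTorch. The function name and call interval are illustrative assumptions, not taken from the experiment code.

import torch

def sync_target(model: torch.nn.Module, target_model: torch.nn.Module):
    # Copy the online parameters theta_i into the target parameters theta_i^-
    target_model.load_state_dict(model.state_dict())

# For example, call `sync_target(model, target_model)` every few thousand training steps.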

labml_nn/rl/dqn/model.py Normal file

@@ -0,0 +1,100 @@
"""
# Neural Network Model
"""
import torch
from torch import nn
from labml_helpers.module import Module
class Model(Module):
"""
## Dueling Network ⚔️ Model for $Q$ Values
We are using a [dueling network](https://arxiv.org/abs/1511.06581)
to calculate Q-values.
The intuition behind the dueling network architecture is that in most states
the action doesn't matter,
while in some states the action is significant. The dueling network allows
this to be represented very well.
\begin{align}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
\\
\mathop{\mathbb{E}}_{a \sim \pi(s)}
\Big[
A^\pi(s, a)
\Big]
&= 0
\end{align}
So we create two networks for $V$ and $A$ and get $Q$ from them.
$$
Q(s, a) = V(s) +
\Big(
A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
\Big)
$$
We share the initial layers of the $V$ and $A$ networks.
"""
def __init__(self):
"""
### Initialize
We keep multiple copies of the model parameters:
one set for the training network and another for the target network.
"""
super().__init__()
self.conv = nn.Sequential(
# The first convolution layer takes an
# 84x84 frame and produces a 20x20 frame
nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
# The second convolution layer takes a
# 20x20 frame and produces a 9x9 frame
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
# The third convolution layer takes a
# 9x9 frame and produces a 7x7 frame
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
)
# A fully connected layer takes the flattened
# frame from the third convolution layer, and outputs
# 512 features
self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
self.state_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=1),
)
self.action_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=4),
)
#
self.activation = nn.ReLU()
def __call__(self, obs: torch.Tensor):
h = self.conv(obs)
h = h.reshape((-1, 7 * 7 * 64))
h = self.activation(self.lin(h))
action_score = self.action_score(h)
state_score = self.state_score(h)
# $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
q = state_score + action_score_centered
return q
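As a quick usage sketch (assuming the `labml_nn` package and its dependencies are installed), a batch of dummy observations can be run through the model to check the Q-value shape; the batch size is an arbitrary choice for illustration.

import torch
from labml_nn.rl.dqn.model import Model

model = Model()
# Two dummy observations: 4 stacked 84x84 frames each, scaled to [0, 1]
obs = torch.rand(2, 4, 84, 84)
q = model(obs)
assert q.shape == (2, 4)  # one Q value per action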


@ -1,3 +1,10 @@
"""
# Prioritized Experience Replay Buffer
This implements the paper [Prioritized experience replay](https://arxiv.org/abs/1511.05952),
using a binary segment tree.
"""
import numpy as np
import random
@@ -5,9 +12,10 @@ import random
class ReplayBuffer:
"""
## Buffer for Prioritized Experience Replay
[Prioritized experience replay](https://arxiv.org/abs/1511.05952)
samples important transitions more frequently.
The transitions are prioritized by the Temporal Difference error.
The transitions are prioritized by the Temporal Difference error (td error).
We sample transition $i$ with probability,
$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
@@ -21,16 +29,16 @@ class ReplayBuffer:
importance-sampling (IS) weights
$$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
that fully compensate for the non-uniform sampling probabilities when $\beta = 1$.
We normalize weights by $1/\max_i w_i$ for stability.
We normalize weights by $\frac{1}{\max_i w_i}$ for stability.
The unbiased nature of the updates is most important near convergence, at the end of training.
Therefore we increase $\beta$ towards the end of training.
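As a small worked example (the numbers are chosen for illustration, not taken from the implementation): with priorities $p = (1, 3)$ and $\alpha = 1$ we get $P(1) = \frac{1}{4}$ and $P(2) = \frac{3}{4}$; with $N = 2$ and $\beta = 1$ the raw weights are $w_1 = \frac{1}{2} \cdot 4 = 2$ and $w_2 = \frac{1}{2} \cdot \frac{4}{3} = \frac{2}{3}$, and after normalizing by $\frac{1}{\max_i w_i} = \frac{1}{2}$ they become $1$ and $\frac{1}{3}$, so the more frequently sampled transition gets the smaller weight.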
### Binary Segment Trees
We use binary segment trees to efficiently calculate
### Binary Segment Tree
We use a binary segment tree to efficiently calculate
$\sum_k^i p_k^\alpha$, the cumulative probability,
which is needed to sample.
We also use a binary segment tree to find $\min p_i^\alpha$,
which is needed for $1/\max_i w_i$.
which is needed for $\frac{1}{\max_i w_i}$.
We can also use a min-heap for this.
This is how a binary segment tree works for sum;
@@ -54,14 +62,16 @@ class ReplayBuffer:
$$N_i = \left\lceil{\frac{N}{2^{D - i}}} \right\rceil$$
This is equal to the sum of nodes in all rows above $i$.
So we can use a single array $a$ to store the tree, where,
$$b_{i,j} = a_{N_1 + j}$$
$$b_{i,j} \rightarrow a_{N_i + j}$$
Then child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
That is,
$$a_i = a_{2i} + a_{2i + 1}$$
This way of maintaining binary trees is very easy to program.
*Note that we are indexing from 1*.
*Note that we are indexing starting from 1*.
We use the same structure to compute the minimum.
"""
def __init__(self, capacity, alpha):
@@ -206,7 +216,7 @@ class ReplayBuffer:
# $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
weight = (prob * self.size) ** (-beta)
# normalize by $\frac{1}{\max_i w_i}$,
# which also cancels off the $\frac{1}/{N}$ term
# which also cancels off the $\frac{1}{N}$ term
samples['weights'][i] = weight / max_weight
# Get the sample data
@@ -230,7 +240,5 @@ class ReplayBuffer:
def is_full(self):
"""
### Is the buffer full
We only start sampling after the buffer is full.
"""
return self.capacity == self.size
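To make the flat-array layout above concrete, here is a small self-contained sketch, separate from the `ReplayBuffer` code and with an illustrative capacity, of a sum segment tree stored in a single 1-indexed array where node $i$ holds the sum of its children $2i$ and $2i + 1$.

import numpy as np

capacity = 4                   # number of leaves; a power of two for simplicity
tree = np.zeros(2 * capacity)  # 1-indexed flat array; leaves live at indices [capacity, 2 * capacity)

def set_priority(idx: int, value: float):
    """Set the priority of leaf `idx` and update all of its ancestors."""
    i = idx + capacity
    tree[i] = value
    i //= 2
    while i >= 1:
        tree[i] = tree[2 * i] + tree[2 * i + 1]
        i //= 2

def prefix_find(prefix_sum: float) -> int:
    """Find the leaf where the cumulative sum of priorities reaches `prefix_sum`."""
    i = 1
    while i < capacity:               # descend until we reach a leaf
        if tree[2 * i] > prefix_sum:  # the answer is in the left subtree
            i = 2 * i
        else:                         # otherwise skip the left subtree's mass
            prefix_sum -= tree[2 * i]
            i = 2 * i + 1
    return i - capacity

set_priority(0, 1.0)
set_priority(1, 3.0)
print(tree[1])           # total priority mass: 4.0
print(prefix_find(2.0))  # falls in leaf 1, which holds most of the mass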