📚 annotations

Varuna Jayasiri
2020-10-24 18:53:02 +05:30
parent 40be459a92
commit 2ddfe1b252
4 changed files with 244 additions and 187 deletions


@@ -8,12 +8,11 @@
import numpy as np
import torch
from torch import nn
from labml import tracker, experiment, logger, monit
from labml_helpers.module import Module
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
@@ -23,105 +22,12 @@ else:
device = torch.device("cpu")
class Model(Module):
"""
## <a name="model"></a>Neural Network Model for $Q$ Values
#### Dueling Network ⚔️
We are using a [dueling network](https://arxiv.org/abs/1511.06581)
to calculate Q-values.
The intuition behind the dueling network architecture is that in most states
the action doesn't matter,
while in some states the action is significant. The dueling network allows
this to be represented very well.
\begin{align}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
\\
\mathop{\mathbb{E}}_{a \sim \pi(s)}
\Big[
A^\pi(s, a)
\Big]
&= 0
\end{align}
So we create two networks for $V$ and $A$ and get $Q$ from them.
$$
Q(s, a) = V(s) +
\Big(
A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
\Big)
$$
We share the initial layers of the $V$ and $A$ networks.
"""
def __init__(self):
"""
### Initialize
We need two instances of this network, with separate sets of parameters,
for the target network and the training network.
"""
super().__init__()
self.conv = nn.Sequential(
# The first convolution layer takes an
# 84x84 frame and produces a 20x20 frame
nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
# The second convolution layer takes a
# 20x20 frame and produces a 9x9 frame
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
# The third convolution layer takes a
# 9x9 frame and produces a 7x7 frame
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
)
# A fully connected layer takes the flattened
# frame from the third convolution layer, and outputs
# 512 features
self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
# A fully connected head gives the state value $V$
self.state_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=1),
)
# A fully connected head gives the advantage $A$ for each of the 4 actions
self.action_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=4),
)
#
self.activation = nn.ReLU()
def __call__(self, obs: torch.Tensor):
# Convolutional features
h = self.conv(obs)
# Flatten the features and pass them through the fully connected layer
h = h.reshape((-1, 7 * 7 * 64))
h = self.activation(self.lin(h))
# $A(s, a)$ and $V(s)$ heads
action_score = self.action_score(h)
state_score = self.state_score(h)
# $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
q = state_score + action_score_centered
return q
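# A minimal illustrative sketch of the mean-centered dueling combination on toy
# tensors (the helper name and shapes are illustrative, not part of labml_nn):
# after subtracting the mean advantage, the advantages average to zero over
# actions, so the mean of $Q(s, \cdot)$ recovers $V(s)$.
def _dueling_combine_demo():
    v = torch.randn(2, 1)  # state values V(s) for a batch of 2 states
    a = torch.randn(2, 4)  # advantages A(s, a) for 4 actions
    q = v + (a - a.mean(dim=-1, keepdim=True))
    # The mean of Q over actions equals V(s) because the centered advantages sum to zero
    assert torch.allclose(q.mean(dim=-1, keepdim=True), v, atol=1e-6)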
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
"""Scale observations from `[0, 255]` to `[0, 1]`"""
return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
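# An illustrative usage sketch (the function name and dummy data are illustrative,
# not part of labml_nn): convert a raw `uint8` frame stack with `obs_to_torch` and
# pick greedy actions from the dueling Q values produced by `Model`.
def _greedy_action_demo() -> torch.Tensor:
    model = Model().to(device)
    # A dummy batch of 8 observations, each 4 stacked 84x84 frames in [0, 255]
    obs = np.random.randint(0, 256, size=(8, 4, 84, 84), dtype=np.uint8)
    q = model(obs_to_torch(obs))  # Q values, shape (8, 4)
    # Greedy action per observation
    return torch.argmax(q, dim=-1)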
class Main(object):
class Trainer:
"""
## <a name="main"></a>Main class
This class runs the training loop.
@@ -239,71 +145,6 @@ class Main(object):
self.obs[w] = next_obs
def train(self, beta: float):
"""
### Train the model
We want to find the optimal action-value function.
\begin{align}
Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
\Big]
\\
Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
r + \gamma \max_{a'} Q^* (s', a') | s, a
\Big]
\end{align}
#### Target network 🎯
In order to improve stability we use an experience replay buffer that randomly samples
from previous experience $U(D)$. We also use a Q network
with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
$\hl1{\theta_i^{-}}$ is updated periodically.
This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).
So the loss function is,
$$
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\bigg[
\Big(
r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
\Big) ^ 2
\bigg]
$$
#### Double $Q$-Learning
The max operator in the above calculation uses the same network both for
selecting the best action and for evaluating its value.
That is,
$$
\max_{a'} Q(s', a'; \theta) = \blue{Q}
\Big(
s', \mathop{\operatorname{argmax}}_{a'}
\blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
\Big)
$$
We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
the $\operatorname{argmax}$ is taken from $\theta_i$ and
the value is taken from $\theta_i^{-}$.
And the loss function becomes,
\begin{align}
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\Bigg[
\bigg(
&r + \gamma \blue{Q}
\Big(
s',
\mathop{\operatorname{argmax}}_{a'}
\green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
\Big)
\\
- &Q(s,a;\theta_i)
\bigg) ^ 2
\Bigg]
\end{align}
"""
for _ in range(self.train_epochs):
# Sample from the prioritized replay buffer
samples = self.replay_buffer.sample(self.mini_batch_size, beta)
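# An illustrative sketch of the double Q-learning target described above (the helper
# name and arguments are illustrative, not from labml_nn): the action is selected with
# the online parameters $\theta_i$ and evaluated with the target parameters $\theta_i^{-}$.
def double_q_target(model: Model, target_model: Model, reward: torch.Tensor,
                    next_obs: torch.Tensor, done: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    with torch.no_grad():
        # $\mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i)$ from the online network
        best_next_action = torch.argmax(model(next_obs), dim=-1, keepdim=True)
        # Evaluate the selected action with the target network parameters $\theta_i^{-}$
        q_next = target_model(next_obs).gather(-1, best_next_action).squeeze(-1)
        # $r + \gamma Q(s', \operatorname{argmax} \dots; \theta_i^{-})$ for non-terminal transitions
        return reward + gamma * q_next * (1 - done)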
@@ -379,7 +220,7 @@ class Main(object):
# ## Run it
if __name__ == "__main__":
experiment.create(name='dqn')
m = Main()
m = Trainer()
with experiment.start():
m.run_training_loop()
m.destroy()