"""
---
title: Deep Q Network (DQN) Model
summary: Implementation of neural network model for Deep Q Network (DQN).
---
# Deep Q Network (DQN) Model
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/a0da8048235511ecb9affd797fa27714)
"""
import torch
from torch import nn

from labml_helpers.module import Module


class Model(Module):
"""
## Dueling Network ⚔️ Model for $Q$ Values
We are using a [dueling network](https://papers.labml.ai/paper/1511.06581)
to calculate Q-values.
Intuition behind dueling network architecture is that in most states
the action doesn't matter,
and in some states the action is significant. Dueling network allows
this to be represented very well.
\begin{align}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
\\
\mathop{\mathbb{E}}_{a \sim \pi(s)}
\Big[
A^\pi(s, a)
\Big]
&= 0
\end{align}
So we create two networks for $V$ and $A$ and get $Q$ from them.
$$
Q(s, a) = V(s) +
\Big(
A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
\Big)
$$
We share the initial layers of the $V$ and $A$ networks.
"""

    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            # The first convolution layer takes an
            # $84\times84$ frame and produces a $20\times20$ frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # $20\times20$ frame and produces a $9\times9$ frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # $9\times9$ frame and produces a $7\times7$ frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
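
        # Size check: with no padding, each convolution maps
        # $W \rightarrow \lfloor (W - K) / S \rfloor + 1$, so
        # $84 \rightarrow (84 - 8)/4 + 1 = 20 \rightarrow (20 - 4)/2 + 1 = 9 \rightarrow (9 - 3)/1 + 1 = 7$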

        # A fully connected layer takes the flattened
        # frame from the third convolution layer, and outputs
        # $512$ features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

        # This head gives the state value $V$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # This head gives the action value $A$, one output per action
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

    def forward(self, obs: torch.Tensor):
        # Convolution
        h = self.conv(obs)
        # Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))
        # Linear layer
        h = self.activation(self.lin(h))

        # $A$
        action_value = self.action_value(h)
        # $V$
        state_value = self.state_value(h)
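
        # A worked example with illustrative numbers (not from the source):
        # if $A = [1, 2, 3, 6]$ for the four actions, the mean is $3$,
        # so the centered scores below come out to $[-2, -1, 0, 3]$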
        # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        q = state_value + action_score_centered

        return q
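

# A minimal smoke test; a sketch, not part of the original experiment. It assumes
# observations are batches of $4$ stacked $84 \times 84$ frames, which is what the
# convolution stack above expects, and only checks the output shape.
if __name__ == '__main__':
    model = Model()
    # A dummy batch of two observations
    dummy_obs = torch.zeros((2, 4, 84, 84))
    # The model returns one $Q$ value per action for each observation
    q_values = model(dummy_obs)
    assert q_values.shape == (2, 4)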