"""
|
|
---
|
|
title: Deep Q Network (DQN) Model
|
|
summary: Implementation of neural network model for Deep Q Network (DQN).
|
|
---
|
|
|
|
# Deep Q Network (DQN) Model
|
|
"""
|
|
|
|
import torch
from torch import nn

from labml_helpers.module import Module


class Model(Module):
    """
    ## Dueling Network ⚔️ Model for $Q$ Values

    We are using a [dueling network](https://papers.labml.ai/paper/1511.06581)
    to calculate Q-values.
    The intuition behind the dueling network architecture is that in most states
    the action doesn't matter,
    while in some states the action is significant.
    The dueling network allows this to be represented very well.

    \begin{align}
        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
        \\
        \mathop{\mathbb{E}}_{a \sim \pi(s)} \Big[ A^\pi(s, a) \Big] &= 0
    \end{align}

    So we create two networks for $V$ and $A$ and get $Q$ from them.
    Subtracting the mean of $A$ over actions keeps the decomposition
    identifiable: adding a constant to $A$ and subtracting it from $V$
    would otherwise leave $Q$ unchanged.
    $$
    Q(s, a) = V(s) +
    \Big(
        A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
    \Big)
    $$

    We share the initial layers of the $V$ and $A$ networks.
    """

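    # A quick numeric check of the centering above (an added note, not part
    # of the original file): with raw advantages $[3, 1, 0, 0]$ for $4$
    # actions the mean is $1$, so the centered advantages are
    # $[2, 0, -1, -1]$. Adding any constant to all raw advantages shifts
    # the mean by the same amount and leaves the centered values, and hence
    # $Q$, unchanged; $V$ alone sets the average of $Q$ over actions.
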
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes an
            # $84\times84$ frame and produces a $20\times20$ frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # $20\times20$ frame and produces a $9\times9$ frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # $9\times9$ frame and produces a $7\times7$ frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
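
        # Added note (not in the original file): the frame sizes in the
        # comments above follow the no-padding convolution arithmetic
        #     out = (in - kernel_size) // stride + 1
        # so (84 - 8) // 4 + 1 = 20, (20 - 4) // 2 + 1 = 9, and
        # (9 - 3) // 1 + 1 = 7, giving 7 * 7 * 64 = 3136 flattened features.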

        # A fully connected layer takes the flattened
        # frame from the third convolution layer, and outputs
        # $512$ features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

        # This head gives the state value $V$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # This head gives the action value $A$,
        # one value per action ($4$ actions here)
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

    def __call__(self, obs: torch.Tensor):
        # Convolution
        h = self.conv(obs)
        # Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))

        # Linear layer
        h = self.activation(self.lin(h))

        # $A$
        action_value = self.action_value(h)
        # $V$
        state_value = self.state_value(h)

        # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        # `state_value` has shape `(batch, 1)` and broadcasts across the actions
        q = state_value + action_score_centered

        return q
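

# A minimal smoke test (an added sketch, not part of the original file;
# it assumes `torch` and `labml_helpers` are installed). It feeds a dummy
# batch of $4$ stacked $84\times84$ frames through the model and checks
# that we get one $Q$ value per action.
if __name__ == '__main__':
    model = Model()
    # A batch with one observation: $4$ stacked $84\times84$ frames
    obs = torch.zeros(1, 4, 84, 84)
    q_values = model(obs)
    # One $Q$ value for each of the $4$ actions
    assert q_values.shape == (1, 4)
    print(q_values.shape)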