📚 annotations

This commit is contained in:
Varuna Jayasiri
2020-10-25 09:44:43 +05:30
parent b492e7c7ca
commit ca258af351
4 changed files with 82 additions and 57 deletions

View File

@ -1,5 +1,5 @@
"""
# Neural Network Model
# Neural Network Model for Deep Q Network (DQN)
"""
import torch
@ -40,61 +40,60 @@ class Model(Module):
"""
def __init__(self):
"""
### Initialize
We need `scope` because we need multiple copies of variables
for target network and training network.
"""
super().__init__()
self.conv = nn.Sequential(
# The first convolution layer takes a
# 84x84 frame and produces a 20x20 frame
# $84\times84$ frame and produces a $20\times20$ frame
nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
# The second convolution layer takes a
# 20x20 frame and produces a 9x9 frame
# $20\times20$ frame and produces a $9\times9$ frame
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
# The third convolution layer takes a
# 9x9 frame and produces a 7x7 frame
# $9\times9$ frame and produces a $7\times7$ frame
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
)
# A fully connected layer takes the flattened
# frame from third convolution layer, and outputs
# 512 features
# $512$ features
self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
self.activation = nn.ReLU()
self.state_score = nn.Sequential(
# This head gives the state value $V$
self.state_value = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=1),
)
self.action_score = nn.Sequential(
# This head gives the action value $A$
self.action_value = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=4),
)
#
self.activation = nn.ReLU()
def __call__(self, obs: torch.Tensor):
# Convolution
h = self.conv(obs)
# Reshape for linear layers
h = h.reshape((-1, 7 * 7 * 64))
# Linear layer
h = self.activation(self.lin(h))
action_score = self.action_score(h)
state_score = self.state_score(h)
# $A$
action_value = self.action_value(h)
# $V$
state_value = self.state_value(h)
# $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
# $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
q = state_score + action_score_centered
q = state_value + action_score_centered
return q