import torch
from torch import nn

from labml_helpers.module import Module

We are using a dueling network to calculate Q-values. The intuition behind the dueling network architecture is that in most states the action doesn't matter, while in some states the choice of action is significant. A dueling network can represent this very well.

So we create two networks for $V$ and $A$ and get $Q$ from them:

$$Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$$

We share the initial layers of the $V$ and $A$ networks.
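To make the aggregation concrete, here is a minimal sketch on dummy tensors (the values and the names `v` and `a` are illustrative, not part of the implementation):

# Dueling aggregation on dummy values: Q = V + (A - mean(A))
v = torch.tensor([[1.0]])                    # V(s): one value per state
a = torch.tensor([[2.0, 0.0, -2.0, 0.0]])    # A(s, a): one value per action
q = v + (a - a.mean(dim=-1, keepdim=True))   # mean(A) is 0 for these values
print(q)                                     # tensor([[ 3.,  1., -1.,  1.]])

Subtracting the mean advantage makes the decomposition identifiable: without it, adding a constant to every $A(s, a)$ and subtracting the same constant from $V(s)$ would produce the same $Q$ values.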
class Model(Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes an 84x84 frame
            # and produces a 20x20 frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            # The second convolution layer takes a 20x20 frame
            # and produces a 9x9 frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            # The third convolution layer takes a 9x9 frame
            # and produces a 7x7 frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # A fully connected layer takes the flattened 7x7x64 frame from
        # the third convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
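        # A quick check of the frame sizes quoted above, as comments
        # (a convolution's output size is floor((input - kernel) / stride) + 1):
        #   conv1: floor((84 - 8) / 4) + 1 = 20
        #   conv2: floor((20 - 4) / 2) + 1 = 9
        #   conv3: floor((9 - 3) / 1) + 1 = 7
        # so in_features=7 * 7 * 64 above matches an 84x84 input frame.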
        self.activation = nn.ReLU()

        # This head gives the state value $V$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # This head gives the action value $A$
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

    def forward(self, obs: torch.Tensor):
        # Convolution
        h = self.conv(obs)
        # Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))
        # Linear layer
        h = self.activation(self.lin(h))

        # $A(s, a)$
        action_value = self.action_value(h)
        # $V(s)$
        state_value = self.state_value(h)

        # Center the advantages: $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')\Big)$
        q = state_value + action_score_centered

        return q
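
As a quick sanity check, here is a usage sketch (the batch size of 8 and the zero-filled observations are arbitrary): the model maps a batch of four stacked 84x84 frames to one Q-value per action.

model = Model()
obs = torch.zeros(8, 4, 84, 84)   # (batch, stacked frames, height, width)
print(model.conv(obs).shape)      # torch.Size([8, 64, 7, 7]): 84 -> 20 -> 9 -> 7
q = model(obs)
print(q.shape)                    # torch.Size([8, 4]): one Q-value per action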