13import torch
14from torch import nn
15
16from labml_helpers.module import Module我们正在使用决斗网络来计算 Q 值。决斗网络架构背后的直觉是,在大多数州,行动无关紧要,而在某些州,行动意义重大。决斗网络可以很好地体现这一点。
因此,我们为和创建了两个网络,然后从中获取。我们共享和网络的初始层。
19class Model(Module):50 def __init__(self):
51 super().__init__()
52 self.conv = nn.Sequential(第一个卷积层需要一个帧并生成一个帧
55 nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
56 nn.ReLU(),第二个卷积层获取一个帧并生成一个帧
60 nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
61 nn.ReLU(),第三个卷积层获取一个帧并生成一个帧
65 nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
66 nn.ReLU(),
67 )完全连接的图层从第三个卷积图层获取展平的帧,并输出要素
72 self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
73 self.activation = nn.ReLU()这个头给出了状态值
76 self.state_value = nn.Sequential(
77 nn.Linear(in_features=512, out_features=256),
78 nn.ReLU(),
79 nn.Linear(in_features=256, out_features=1),
80 )这个头给出了动作值
82 self.action_value = nn.Sequential(
83 nn.Linear(in_features=512, out_features=256),
84 nn.ReLU(),
85 nn.Linear(in_features=256, out_features=4),
86 )88 def forward(self, obs: torch.Tensor):卷积
90 h = self.conv(obs)线性图层的整形
92 h = h.reshape((-1, 7 * 7 * 64))线性层
95 h = self.activation(self.lin(h))98 action_value = self.action_value(h)100 state_value = self.state_value(h)103 action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)105 q = state_value + action_score_centered
106
107 return q