深度问网络 (DQN) 模型

Open In ColabView Run

13import torch
14from torch import nn
15
16from labml_helpers.module import Module

决斗网络 ⚔️ 价值观模型

我们正在使用决斗网络来计算 Q 值。决斗网络架构背后的直觉是,在大多数州,行动无关紧要,而在某些州,行动意义重大。决斗网络可以很好地体现这一点。

因此,我们为和创建了两个网络,然后从中获取。我们共享网络的初始层。

19class Model(Module):
50    def __init__(self):
51        super().__init__()
52        self.conv = nn.Sequential(

第一个卷积层需要一个帧并生成一个

55            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
56            nn.ReLU(),

第二个卷积层获取一个帧并生成一个

60            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
61            nn.ReLU(),

第三个卷积层获取一个帧并生成一个

65            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
66            nn.ReLU(),
67        )

完全连接的图层从第三个卷积图层获取展平的帧,并输出要素

72        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
73        self.activation = nn.ReLU()

这个头给出了状态值

76        self.state_value = nn.Sequential(
77            nn.Linear(in_features=512, out_features=256),
78            nn.ReLU(),
79            nn.Linear(in_features=256, out_features=1),
80        )

这个头给出了动作值

82        self.action_value = nn.Sequential(
83            nn.Linear(in_features=512, out_features=256),
84            nn.ReLU(),
85            nn.Linear(in_features=256, out_features=4),
86        )
88    def forward(self, obs: torch.Tensor):

卷积

90        h = self.conv(obs)

线性图层的整形

92        h = h.reshape((-1, 7 * 7 * 64))

线性层

95        h = self.activation(self.lin(h))

98        action_value = self.action_value(h)

100        state_value = self.state_value(h)

103        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

105        q = state_value + action_score_centered
106
107        return q