深度问答网络 (DQN)

这是 PyTorch 实现的论文《玩雅达利与深度强化学习》以及《决斗网络》、《优先重播》和 Double Q Network。

以下是实验模型实现。

Open In ColabView Run

25from typing import Tuple
26
27import torch
28from torch import nn
29
30from labml import tracker
31from labml_helpers.module import Module
32from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

训练模型

我们想找到最优的动作值函数。

目标网络 🎯

为了提高稳定性,我们使用从之前的体验中随机抽样的体验重播。我们还使用带有单独参数集的 Q 网络来计算目标。会定期更新。这是根据论文《通过深度强化学习控制人文水平》所说的。

所以损失函数是,

双重学习

上述计算中的最大运算符使用相同的网络来选择最佳操作和评估值。也就是说,我们使用双 Q 学习,其中取自,值取自

然后损失函数变成,

35class QFuncLoss(Module):
103    def __init__(self, gamma: float):
104        super().__init__()
105        self.gamma = gamma
106        self.huber_loss = nn.SmoothL1Loss(reduction='none')
  • q -
  • action -
  • double_q -
  • target_q -
  • done -游戏在采取行动后是否结束
  • reward -
  • weights -来自有经验的优先重播的样本的权重
  • 108    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
    109                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
    110                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

    122        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
    123        tracker.add('q_sampled_action', q_sampled_action)

    渐变不应传播渐变

    131        with torch.no_grad():

    在州内采取最佳行动

    135            best_next_action = torch.argmax(double_q, -1)

    从目标网络获取 q 值,以便在州内采取最佳行动

    141            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

    计算所需的 Q 值。如果游戏结束,我们将乘(1 - done) 以将下一个状态 Q 值归零。

    152            q_update = reward + self.gamma * best_next_q_value * (1 - done)
    153            tracker.add('q_update', q_update)

    时差误差用于称量重放缓冲区中的样本

    156            td_error = q_sampled_action - q_update
    157            tracker.add('td_error', td_error)

    我们采用 Huber 损失而不是均方误差损失,因为它对异常值不太敏感

    161        losses = self.huber_loss(q_sampled_action, q_update)

    获取加权均值

    163        loss = torch.mean(weights * losses)
    164        tracker.add('loss', loss)
    165
    166        return td_error, loss