This is a PyTorch implementation of the paper Playing Atari with Deep Reinforcement Learning, along with Dueling Network, Prioritized Replay and Double Q Network.
Here are the experiment and model implementations.
```python
from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
```

We want to find the optimal action-value function, which satisfies the Bellman equation

$$Q^*(s, a) = \mathbb{E}_{s'} \Big[ r + \gamma \max_{a'} Q^*(s', a') \,\Big|\, s, a \Big]$$
In order to improve stability we use experience replay that randomly samples from previous experience $U(D)$. We also use a Q network with a separate set of parameters $\color{orange}{\theta_i^{-}}$ to calculate the target. $\color{orange}{\theta_i^{-}}$ is updated periodically. This is according to the paper Human Level Control Through Deep Reinforcement Learning.

So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \bigg[ \Big( r + \gamma \max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) - Q(s, a; \theta_i) \Big)^2 \bigg]$$
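As an illustrative sketch (not part of this implementation), a target network with parameters $\color{orange}{\theta_i^{-}}$ can be maintained as a periodically refreshed copy of the online network; the network architecture and the update interval below are arbitrary assumptions.

```python
import copy

from torch import nn

# A stand-in online Q network (architecture is an arbitrary assumption)
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

# Target network: a frozen copy holding the parameters θ_i⁻
target_net = copy.deepcopy(q_net)
for param in target_net.parameters():
    param.requires_grad = False

def update_target_network(step: int, interval: int = 1_000):
    # Periodically copy the online parameters θ_i into the target network
    if step % interval == 0:
        target_net.load_state_dict(q_net.state_dict())
```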
The max operator in the above calculation uses the same network both for selecting the best action and for evaluating its value. That is,

$$\max_{a'} Q(s', a'; \theta) = \color{cyan}{Q}\Big(s', \operatorname{argmax}_{a'} \color{cyan}{Q}(s', a'; \color{cyan}{\theta}); \color{cyan}{\theta}\Big)$$

We use double Q-learning, where the $\operatorname{argmax}$ is taken from $\color{cyan}{\theta_i}$ and the value is taken from $\color{orange}{\theta_i^{-}}$.

And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \Bigg[ \bigg( r + \gamma \color{orange}{Q}\Big(s', \operatorname{argmax}_{a'} \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}\Big) - Q(s, a; \theta_i) \bigg)^2 \Bigg]$$
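To make the difference concrete, here is a small sketch with made-up numbers contrasting the standard target (both the $\operatorname{argmax}$ and the value from the target network) with the double Q-learning target ($\operatorname{argmax}$ from the online network, value from the target network); all tensor values below are illustrative only.

```python
import torch

gamma = 0.99
reward = torch.tensor([1.0, 0.0])                       # r for two sampled transitions
q_online_next = torch.tensor([[1.0, 2.0], [0.5, 0.3]])  # Q(s', ·; θ_i)
q_target_next = torch.tensor([[1.5, 0.1], [0.2, 0.9]])  # Q(s', ·; θ_i⁻)

# Standard DQN target: both argmax and value come from the target network
standard_target = reward + gamma * q_target_next.max(dim=-1).values

# Double DQN target: argmax from the online network, value from the target network
best_action = q_online_next.argmax(dim=-1)
double_target = reward + gamma * q_target_next.gather(-1, best_action.unsqueeze(-1)).squeeze(-1)
```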
```python
class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
```

Arguments to `__call__`:

* `q` - $Q(s; \theta_i)$
* `action` - $a$
* `double_q` - $\color{cyan}{Q}(s'; \color{cyan}{\theta_i})$
* `target_q` - $\color{orange}{Q}(s'; \color{orange}{\theta_i^{-}})$
* `done` - whether the game ended after taking the action
* `reward` - $r$
* `weights` - weights of the samples from prioritized experience replay

```python
    def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Q values for the sampled actions, Q(s, a; θ_i)
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

        # Gradients shouldn't propagate through the target calculation
        with torch.no_grad():
            # Get the best action at state s'
            best_next_action = torch.argmax(double_q, -1)
            # Get the Q value from the target network for the best action at state s'
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

            # Calculate the desired Q value.
            # We multiply by (1 - done) to zero out the next state Q values if the game ended.
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)

            # Temporal difference error δ is used to weigh samples in the replay buffer
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

        # We take Huber loss instead of mean squared error loss because it is less sensitive to outliers
        losses = self.huber_loss(q_sampled_action, q_update)
        # Get the weighted mean
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
```
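Below is a minimal, hypothetical usage sketch with random tensors; the batch size, number of actions, and discount factor are arbitrary. In the actual experiment these inputs come from the model and the prioritized replay buffer, and `tracker.add` records metrics when run inside a labml experiment.

```python
import torch

batch_size, n_actions = 32, 4
loss_func = QFuncLoss(gamma=0.99)

q = torch.randn(batch_size, n_actions, requires_grad=True)  # Q(s, ·; θ_i)
double_q = torch.randn(batch_size, n_actions)                # Q(s', ·; θ_i)
target_q = torch.randn(batch_size, n_actions)                # Q(s', ·; θ_i⁻)
action = torch.randint(0, n_actions, (batch_size,))
reward = torch.randn(batch_size)
done = torch.zeros(batch_size)      # 1.0 where the episode ended after the action
weights = torch.ones(batch_size)    # importance-sampling weights from prioritized replay

td_error, loss = loss_func(q, action, double_q, target_q, done, reward, weights)
loss.backward()
# |td_error| is typically fed back to the prioritized replay buffer as updated priorities.
```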