This is a PyTorch implementation of the paper Playing Atari with Deep Reinforcement Learning, along with Dueling Network, Prioritized Replay and Double Q Network.
Here is the experiment and model implementation.
from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

We want to find the optimal action-value function,

$$Q^*(s, a) = \max_\pi \mathbb{E} \Big[ r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots \;\Big|\; s_t = s, a_t = a, \pi \Big]$$
In order to improve stability we use experience replay that randomly samples from previous experience $U(D)$. We also use a Q network with a separate set of parameters $\theta_i^-$ to calculate the target. $\theta_i^-$ is updated periodically. This is according to the paper Human Level Control Through Deep Reinforcement Learning.
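$\theta_i^-$ is commonly refreshed by copying the online parameters every fixed number of steps. Below is a minimal sketch of such a periodic hard update, not part of this module; `q_net`, `target_net`, `step` and `target_update_interval` are assumed names.

# Sketch of a periodic hard update of the target network; names are assumptions
if step % target_update_interval == 0:
    # Copy the online parameters θ_i into the target parameters θ_i⁻
    target_net.load_state_dict(q_net.state_dict())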
So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \bigg[ \Big( r + \gamma \max_{a'} Q(s', a'; \theta_i^-) - Q(s, a; \theta_i) \Big)^2 \bigg]$$
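For comparison with the double Q-learning target below, here is a minimal sketch of this single-network target, assuming a hypothetical `target_net` that maps a batch of next states to per-action Q values.

# Sketch only; target_net, next_states, reward, done and gamma are assumed names
with torch.no_grad():
    # max_a' Q(s', a'; θ_i⁻): one network both selects and evaluates the action
    max_next_q = target_net(next_states).max(dim=-1).values
    target = reward + gamma * max_next_q * (1 - done)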
The max operator in the above calculation uses the same network for both selecting the best action and for evaluating the value. That is,

$$\max_{a'} Q(s', a'; \theta) = Q\Big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta); \theta\Big)$$

We use double Q-learning, where the $\operatorname*{argmax}$ is taken from $\theta_i$ and the value is taken from $\theta_i^-$.
And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \Bigg[ \bigg( r + \gamma Q\Big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\Big) - Q(s, a; \theta_i) \bigg)^2 \Bigg]$$
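A minimal sketch of this double Q-learning target, under the same assumed names as above plus an online `q_net` for $\theta_i$; the `QFuncLoss` class below implements the same idea on pre-computed Q values.

# Sketch only; q_net (θ_i) and target_net (θ_i⁻) are assumed names
with torch.no_grad():
    # Select the best next action with the online network θ_i
    best_next_action = q_net(next_states).argmax(dim=-1)
    # Evaluate that action with the target network θ_i⁻
    next_q = target_net(next_states).gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
    target = reward + gamma * next_q * (1 - done)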
class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
- `q` - $Q(s; \theta_i)$
- `action` - $a$
- `double_q` - $Q(s'; \theta_i)$
- `target_q` - $Q(s'; \theta_i^-)$
- `done` - whether the game ended after taking the action
- `reward` - $r$
- `weights` - weights of the samples from prioritized experience replay

    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

Q value of the sampled action, $Q(s, a; \theta_i)$

        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

Gradients shouldn't propagate through the Q value targets

        with torch.no_grad():

Get the best action at state $s'$, $\operatorname*{argmax}_{a'} Q(s', a'; \theta_i)$

            best_next_action = torch.argmax(double_q, -1)

Get the Q value from the target network for the best action at state $s'$, $Q(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-)$

            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

Calculate the desired Q value. We multiply by `(1 - done)` to zero out the next state Q values if the game ended.

            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)

Temporal difference error is used to weigh samples in the replay buffer

            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

We take Huber loss instead of mean squared error loss because it is less sensitive to outliers

        losses = self.huber_loss(q_sampled_action, q_update)

Get weighted means

        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
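For completeness, a hypothetical usage sketch of `QFuncLoss` on a sampled batch; `q_net`, `target_net` and the batch tensors are assumed names, with Q tensors of shape `[batch, n_actions]` and the remaining tensors of shape `[batch]`.

# Hypothetical usage sketch; all names other than QFuncLoss are assumptions
loss_func = QFuncLoss(gamma=0.99)
q = q_net(states)  # Q(s; θ_i), keeps gradients for the update
with torch.no_grad():
    double_q = q_net(next_states)      # Q(s'; θ_i), used to pick the best next action
    target_q = target_net(next_states)  # Q(s'; θ_i⁻), used to evaluate that action
td_error, loss = loss_func(q, actions, double_q, target_q, done, reward, weights)
loss.backward()
# The absolute TD errors can then be fed back to the prioritized replay buffer as new priorities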