mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-04 14:29:43 +08:00

📚 annotations
labml_nn/rl/dqn/__init__.py

@@ -1,13 +1,24 @@
 """
-This is a Deep Q Learning implementation with:
+# Deep Q Networks
+
+This is a Deep Q Learning implementation that uses:
+
+* [Dueling Network](model.html)
+* [Prioritized Replay](replay_buffer.html)
 * Double Q Network
-* Dueling Network
-* Prioritized Replay
+
+Here's the [experiment](experiment.html) and [model](model.html).
+
+\(
+   \def\green#1{{\color{yellowgreen}{#1}}}
+\)
+
 """

 from typing import Tuple

 import torch
+from torch import nn

 from labml import tracker
 from labml_helpers.module import Module
@@ -15,36 +26,133 @@ from labml_nn.rl.dqn.replay_buffer import ReplayBuffer


 class QFuncLoss(Module):
+    """
+    ## Train the model
+
+    We want to find the optimal action-value function.
+
+    \begin{align}
+        Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
+            r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
+        \Big]
+    \\
+        Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
+            r + \gamma \max_{a'} Q^* (s', a') | s, a
+        \Big]
+    \end{align}
+
+    ### Target network 🎯
+    In order to improve stability we use experience replay that randomly samples
+    from previous experience $U(D)$. We also use a Q network
+    with a separate set of parameters $\color{orange}{\theta_i^{-}}$ to calculate the target.
+    $\color{orange}{\theta_i^{-}}$ is updated periodically.
+    This is according to the paper
+    [Human Level Control Through Deep Reinforcement Learning](https://deepmind.com/research/dqn/).
+
+    So the loss function is,
+    $$
+    \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
+    \bigg[
+        \Big(
+            r + \gamma \max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) - Q(s,a;\theta_i)
+        \Big) ^ 2
+    \bigg]
+    $$
+
+    ### Double $Q$-Learning
+    The max operator in the above calculation uses the same network for both
+    selecting the best action and for evaluating the value.
+    That is,
+    $$
+    \max_{a'} Q(s', a'; \theta) = \color{cyan}{Q}
+    \Big(
+        s', \mathop{\operatorname{argmax}}_{a'}
+        \color{cyan}{Q}(s', a'; \color{cyan}{\theta}); \color{cyan}{\theta}
+    \Big)
+    $$
+    We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
+    the $\operatorname{argmax}$ is taken from $\color{cyan}{\theta_i}$ and
+    the value is taken from $\color{orange}{\theta_i^{-}}$.
+
+    And the loss function becomes,
+    \begin{align}
+        \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
+        \Bigg[
+            \bigg(
+                &r + \gamma \color{orange}{Q}
+                \Big(
+                    s',
+                    \mathop{\operatorname{argmax}}_{a'}
+                        \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
+                \Big)
+                \\
+                - &Q(s,a;\theta_i)
+            \bigg) ^ 2
+        \Bigg]
+    \end{align}
+    """
+
     def __init__(self, gamma: float):
         super().__init__()
         self.gamma = gamma
+        self.huber_loss = nn.SmoothL1Loss(reduction='none')

-    def __call__(self, q: torch.Tensor,
-                 action: torch.Tensor,
-                 double_q: torch.Tensor,
-                 target_q: torch.Tensor,
-                 done: torch.Tensor,
-                 reward: torch.Tensor,
+    def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
+                 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                  weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        * `q` - $Q(s;\theta_i)$
+        * `action` - $a$
+        * `double_q` - $\color{cyan}Q(s';\color{cyan}{\theta_i})$
+        * `target_q` - $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
+        * `done` - whether the game ended after taking the action
+        * `reward` - $r$
+        * `weights` - weights of the samples from prioritized experience replay
+        """
+
+        # $Q(s,a;\theta_i)$
         q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
         tracker.add('q_sampled_action', q_sampled_action)

+        # Gradients shouldn't propagate through the target
+        # $$r + \gamma \color{orange}{Q}
+        #                 \Big(s',
+        #                     \mathop{\operatorname{argmax}}_{a'}
+        #                         \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
+        #                 \Big)$$
         with torch.no_grad():
+            # Get the best action at state $s'$
+            # $$\mathop{\operatorname{argmax}}_{a'}
+            #  \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i})$$
             best_next_action = torch.argmax(double_q, -1)
+            # Get the q value from the target network for the best action at state $s'$
+            # $$\color{orange}{Q}
+            # \Big(s',\mathop{\operatorname{argmax}}_{a'}
+            # \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
+            # \Big)$$
             best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

-            best_next_q_value *= (1 - done)
-            q_update = reward + self.gamma * best_next_q_value
+            # Calculate the desired Q value.
+            # We multiply by `(1 - done)` to zero out
+            # the next state Q values if the game ended.
+            #
+            # $$r + \gamma \color{orange}{Q}
+            #                 \Big(s',
+            #                     \mathop{\operatorname{argmax}}_{a'}
+            #                         \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
+            #                 \Big)$$
+            q_update = reward + self.gamma * best_next_q_value * (1 - done)
             tracker.add('q_update', q_update)

+            # Temporal difference error $\delta$ is used to weigh samples in the replay buffer
             td_error = q_sampled_action - q_update
             tracker.add('td_error', td_error)

-        # Huber loss
-        losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
+        # We take [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) instead of
+        # mean squared error loss because it is less sensitive to outliers
+        losses = self.huber_loss(q_sampled_action, q_update)
+        # Get weighted means
         loss = torch.mean(weights * losses)
         tracker.add('loss', loss)

         return td_error, loss
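A minimal numeric sketch of the double Q-learning target that the loss above builds inside `torch.no_grad()`; all tensor values are made up for illustration and are not part of the commit.

import torch

gamma = 0.99
# Q(s', ·; θ_i)  -- online network estimates, used only to pick the argmax action
double_q = torch.tensor([[1.0, 2.0, 0.5],
                         [0.3, 0.1, 0.2]])
# Q(s', ·; θ_i⁻) -- target network estimates, used for the value of that action
target_q = torch.tensor([[0.8, 1.5, 0.9],
                         [0.4, 0.6, 0.2]])
reward = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 1.0])  # the second transition ended the episode

# argmax over the online network ...
best_next_action = torch.argmax(double_q, dim=-1)                               # tensor([1, 0])
# ... but the value comes from the target network
best_next_q = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)   # tensor([1.5000, 0.4000])
# Bellman target, zeroed out for terminal transitions
q_update = reward + gamma * best_next_q * (1 - done)                            # tensor([2.4850, 0.0000])
print(q_update)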
labml_nn/rl/dqn/experiment.py

@@ -8,12 +8,11 @@

 import numpy as np
 import torch
-from torch import nn

 from labml import tracker, experiment, logger, monit
-from labml_helpers.module import Module
 from labml_helpers.schedule import Piecewise
 from labml_nn.rl.dqn import QFuncLoss
+from labml_nn.rl.dqn.model import Model
 from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
 from labml_nn.rl.game import Worker

@@ -23,105 +22,12 @@ else:
     device = torch.device("cpu")


-class Model(Module):
-    """
-    ## <a name="model"></a>Neural Network Model for $Q$ Values
-
-    #### Dueling Network ⚔️
-    ... (the rest of the dueling network `Model` class, roughly 90 lines, is removed here;
-    it moves, with minor docstring edits, to the new file labml_nn/rl/dqn/model.py shown further down)
-
-
 def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
     """Scale observations from `[0, 255]` to `[0, 1]`"""
     return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


-class Main(object):
+class Trainer:
     """
     ## <a name="main"></a>Main class
     This class runs the training loop.
@@ -239,71 +145,6 @@ class Main(object):
                     self.obs[w] = next_obs

     def train(self, beta: float):
-        """
-        ### Train the model
-
-        We want to find optimal action-value function.
-        ... (the rest of this docstring -- the Q-learning objective, the target network 🎯 and
-        the double $Q$-learning derivation, roughly 60 lines -- is removed here; the same material
-        now lives in the `QFuncLoss` docstring shown above)
-        """
-
         for _ in range(self.train_epochs):
             # sample from priority replay buffer
             samples = self.replay_buffer.sample(self.mini_batch_size, beta)
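A hypothetical sketch of how a sampled batch like the one above is typically pushed through `QFuncLoss`; the helper names (`model`, `target_model`, `update_priorities`) and the `samples` keys are assumptions for illustration, not taken from this diff.

import numpy as np
import torch

def train_step(model, target_model, loss_func, optimizer, replay_buffer, samples, device):
    """One hypothetical training step on a prioritized-replay batch (names are assumed)."""
    # scale observations to [0, 1], as `obs_to_torch` does
    obs = torch.tensor(samples['obs'], dtype=torch.float32, device=device) / 255.
    next_obs = torch.tensor(samples['next_obs'], dtype=torch.float32, device=device) / 255.

    q = model(obs)                          # Q(s, ·; θ_i)
    with torch.no_grad():
        double_q = model(next_obs)          # Q(s', ·; θ_i)  -- picks the argmax action
        target_q = target_model(next_obs)   # Q(s', ·; θ_i⁻) -- provides that action's value

    td_error, loss = loss_func(q,
                               q.new_tensor(samples['action']),
                               double_q, target_q,
                               q.new_tensor(samples['done']),
                               q.new_tensor(samples['reward']),
                               q.new_tensor(samples['weights']))

    # |TD error| becomes the new priority of each sampled transition
    # (assumed buffer method; a small constant keeps priorities non-zero)
    replay_buffer.update_priorities(samples['indexes'],
                                    np.abs(td_error.cpu().numpy()) + 1e-6)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()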
@@ -379,7 +220,7 @@ class Main(object):
 # ## Run it
 if __name__ == "__main__":
     experiment.create(name='dqn')
-    m = Main()
+    m = Trainer()
     with experiment.start():
         m.run_training_loop()
     m.destroy()
labml_nn/rl/dqn/model.py (new file, 100 lines)

@@ -0,0 +1,100 @@
					"""
 | 
				
			||||||
 | 
					# Neural Network Model
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from labml_helpers.module import Module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Model(Module):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    ## Dueling Network ⚔️ Model for $Q$ Values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    We are using a [dueling network](https://arxiv.org/abs/1511.06581)
 | 
				
			||||||
 | 
					     to calculate Q-values.
 | 
				
			||||||
 | 
					    Intuition behind dueling network architecture is that in most states
 | 
				
			||||||
 | 
					     the action doesn't matter,
 | 
				
			||||||
 | 
					    and in some states the action is significant. Dueling network allows
 | 
				
			||||||
 | 
					     this to be represented very well.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    \begin{align}
 | 
				
			||||||
 | 
					        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
 | 
				
			||||||
 | 
					        \\
 | 
				
			||||||
 | 
					        \mathop{\mathbb{E}}_{a \sim \pi(s)}
 | 
				
			||||||
 | 
					         \Big[
 | 
				
			||||||
 | 
					          A^\pi(s, a)
 | 
				
			||||||
 | 
					         \Big]
 | 
				
			||||||
 | 
					        &= 0
 | 
				
			||||||
 | 
					    \end{align}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    So we create two networks for $V$ and $A$ and get $Q$ from them.
 | 
				
			||||||
 | 
					    $$
 | 
				
			||||||
 | 
					        Q(s, a) = V(s) +
 | 
				
			||||||
 | 
					        \Big(
 | 
				
			||||||
 | 
					            A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
 | 
				
			||||||
 | 
					        \Big)
 | 
				
			||||||
 | 
					    $$
 | 
				
			||||||
 | 
					    We share the initial layers of the $V$ and $A$ networks.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        ### Initialize
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        We need `scope` because we need multiple copies of variables
 | 
				
			||||||
 | 
					         for target network and training network.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.conv = nn.Sequential(
 | 
				
			||||||
 | 
					            # The first convolution layer takes a
 | 
				
			||||||
 | 
					            # 84x84 frame and produces a 20x20 frame
 | 
				
			||||||
 | 
					            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # The second convolution layer takes a
 | 
				
			||||||
 | 
					            # 20x20 frame and produces a 9x9 frame
 | 
				
			||||||
 | 
					            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # The third convolution layer takes a
 | 
				
			||||||
 | 
					            # 9x9 frame and produces a 7x7 frame
 | 
				
			||||||
 | 
					            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # A fully connected layer takes the flattened
 | 
				
			||||||
 | 
					        # frame from third convolution layer, and outputs
 | 
				
			||||||
 | 
					        # 512 features
 | 
				
			||||||
 | 
					        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.state_score = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(in_features=512, out_features=256),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(in_features=256, out_features=1),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.action_score = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(in_features=512, out_features=256),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(in_features=256, out_features=4),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        self.activation = nn.ReLU()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, obs: torch.Tensor):
 | 
				
			||||||
 | 
					        h = self.conv(obs)
 | 
				
			||||||
 | 
					        h = h.reshape((-1, 7 * 7 * 64))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        h = self.activation(self.lin(h))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        action_score = self.action_score(h)
 | 
				
			||||||
 | 
					        state_score = self.state_score(h)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
 | 
				
			||||||
 | 
					        action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					        q = state_score + action_score_centered
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return q
 | 
				
			||||||
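A short usage sketch for the dueling `Model` above: a batch of four stacked 84x84 frames maps to one Q value per action, and a separate copy of the network serves as the target network. The `target_model` name is only for illustration.

import torch
from labml_nn.rl.dqn.model import Model

model = Model()
target_model = Model()

# a batch of 8 observations: 4 stacked frames of 84x84 pixels each
obs = torch.rand(8, 4, 84, 84)
q = model(obs)
print(q.shape)  # torch.Size([8, 4]) -- one Q value per action

# periodically copy the online parameters into the target network
target_model.load_state_dict(model.state_dict())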
labml_nn/rl/dqn/replay_buffer.py

@@ -1,3 +1,10 @@
+"""
+# Prioritized Experience Replay Buffer
+
+This implements the paper [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952),
+using a binary segment tree.
+"""
+
 import numpy as np
 import random

@@ -5,9 +12,10 @@ import random
 class ReplayBuffer:
     """
     ## Buffer for Prioritized Experience Replay
+
     [Prioritized experience replay](https://arxiv.org/abs/1511.05952)
      samples important transitions more frequently.
-    The transitions are prioritized by the Temporal Difference error.
+    The transitions are prioritized by the Temporal Difference error (TD error).
+
     We sample transition $i$ with probability,
     $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
@@ -21,16 +29,16 @@ class ReplayBuffer:
      importance-sampling (IS) weights
     $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
     that fully compensates for when $\beta = 1$.
-    We normalize weights by $1/\max_i w_i$ for stability.
+    We normalize weights by $\frac{1}{\max_i w_i}$ for stability.
     Unbiased nature is most important towards the convergence at end of training.
     Therefore we increase $\beta$ towards end of training.

-    ### Binary Segment Trees
-    We use binary segment trees to efficiently calculate
+    ### Binary Segment Tree
+    We use a binary segment tree to efficiently calculate
     $\sum_k^i p_k^\alpha$, the cumulative probability,
     which is needed to sample.
     We also use a binary segment tree to find $\min p_i^\alpha$,
-    which is needed for $1/\max_i w_i$.
+    which is needed for $\frac{1}{\max_i w_i}$.
     We can also use a min-heap for this.

     This is how a binary segment tree works for sum;
@@ -54,14 +62,16 @@ class ReplayBuffer:
     $$N_i = \left\lceil{\frac{N}{D - i + i}} \right\rceil$$
     This is equal to the sum of nodes in all rows above $i$.
     So we can use a single array $a$ to store the tree, where,
-    $$b_{i,j} = a_{N_1 + j}$$
+    $$b_{i,j} \rightarrow a_{N_i + j}$$
+
     Then child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
     That is,
     $$a_i = a_{2i} + a_{2i + 1}$$

     This way of maintaining binary trees is very easy to program.
-    *Note that we are indexing from 1*.
+    *Note that we are indexing starting from 1*.

+    We use the same structure to compute the minimum.
     """

     def __init__(self, capacity, alpha):
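A minimal sketch, assuming a power-of-two capacity, of the flat-array sum tree described above: leaves hold the priorities, every internal node stores $a_i = a_{2i} + a_{2i + 1}$, and sampling walks down from the root by cumulative sum. The function names here are illustrative, not the buffer's actual API.

# 1-indexed flat array; leaves live at indices capacity .. 2*capacity - 1
capacity = 8
tree = [0.0] * (2 * capacity)

def set_priority(idx, value):
    """Set leaf `idx` (0-based) and update all of its ancestors."""
    i = idx + capacity
    tree[i] = value
    while i >= 2:
        i //= 2
        tree[i] = tree[2 * i] + tree[2 * i + 1]   # a_i = a_2i + a_{2i+1}

def find_prefix_sum_idx(prefix_sum):
    """Walk down from the root to the leaf whose cumulative range contains `prefix_sum`."""
    i = 1
    while i < capacity:
        if tree[2 * i] > prefix_sum:
            i = 2 * i                      # answer is in the left subtree
        else:
            prefix_sum -= tree[2 * i]      # skip the left subtree's mass
            i = 2 * i + 1                  # continue in the right subtree
    return i - capacity                    # back to a 0-based leaf index

set_priority(0, 2.0)
set_priority(1, 1.0)
set_priority(2, 0.5)
print(tree[1])                   # 3.5 -- total priority mass at the root
print(find_prefix_sum_idx(2.3))  # 1  -- falls in the second leaf's range [2.0, 3.0)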
@@ -206,7 +216,7 @@ class ReplayBuffer:
             # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
             weight = (prob * self.size) ** (-beta)
             # normalize by $\frac{1}{\max_i w_i}$,
-            #  which also cancels off the $\frac{1}/{N}$ term
+            #  which also cancels off the $\frac{1}{N}$ term
             samples['weights'][i] = weight / max_weight

         # get samples data
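A small numeric sketch of the sampling probabilities and importance-sampling weights computed above, using made-up priorities; it is only an illustration of the math, not code from the buffer.

import numpy as np

priorities = np.array([2.0, 1.0, 0.5, 0.5])   # p_i, e.g. |TD error| of stored transitions
alpha, beta = 0.6, 0.4

probs = priorities ** alpha
probs /= probs.sum()                          # P(i) = p_i^α / Σ_k p_k^α

n = len(priorities)
weights = (n * probs) ** (-beta)              # w_i = (1/N · 1/P(i))^β
weights /= weights.max()                      # normalize by 1/max_i w_i
print(probs, weights)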
@@ -230,7 +240,5 @@ class ReplayBuffer:
     def is_full(self):
         """
         ### Is the buffer full
-
-        We only start sampling afte the buffer is full.
         """
         return self.capacity == self.size