mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-04 14:29:43 +08:00

✨ DQN
50  labml_nn/rl/dqn/__init__.py  Normal file

@ -0,0 +1,50 @@
"""
This is a Deep Q Learning implementation with:
* Double Q Network
* Dueling Network
* Prioritized Replay
"""

from typing import Tuple

import torch

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer


class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        # discount factor $\gamma$
        self.gamma = gamma

    def __call__(self, q: torch.Tensor,
                 action: torch.Tensor,
                 double_q: torch.Tensor,
                 target_q: torch.Tensor,
                 done: torch.Tensor,
                 reward: torch.Tensor,
                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # $Q(s, a)$ for the actions that were actually taken
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

        with torch.no_grad():
            # double Q-learning: pick $a'$ with the online network (`double_q`) ...
            best_next_action = torch.argmax(double_q, -1)
            # ... and evaluate it with the target network (`target_q`)
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

            # no bootstrapping from terminal states
            best_next_q_value *= (1 - done)

            # target $r + \gamma Q(s', a')$
            q_update = reward + self.gamma * best_next_q_value
            tracker.add('q_update', q_update)

            # temporal difference error, used to update replay priorities
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

        # Huber loss
        losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
        # weigh the losses by the importance-sampling weights from prioritized replay
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
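Below is a minimal standalone sketch of the double Q-learning target that `QFuncLoss` computes, using made-up tensors for a batch of two samples and four actions (the values are illustrative only):

import torch

gamma = 0.99
q        = torch.tensor([[1.0, 2.0, 0.5, 0.1],   # online network, Q(s, .)
                         [0.3, 0.2, 0.9, 0.4]])
double_q = torch.tensor([[0.8, 1.5, 0.2, 0.0],   # online network, Q(s', .)
                         [0.1, 0.7, 0.6, 0.2]])
target_q = torch.tensor([[0.9, 1.1, 0.3, 0.2],   # target network, Q(s', .)
                         [0.2, 0.5, 0.8, 0.3]])
action = torch.tensor([1, 2])
reward = torch.tensor([1.0, 0.0])
done   = torch.tensor([0.0, 1.0])

# Q value of the action actually taken -> [2.0, 0.9]
q_sampled_action = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
# choose a' with the online network, evaluate it with the target network
best_next_action = torch.argmax(double_q, -1)
best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
# zero the bootstrap term for terminal transitions
q_update = reward + gamma * best_next_q_value * (1 - done)   # -> [1 + 0.99 * 1.1, 0.0]
td_error = q_sampled_action - q_update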
				
385  labml_nn/rl/dqn/experiment.py  Normal file

@ -0,0 +1,385 @@
"""
\(
   \def\hl1#1{{\color{orange}{#1}}}
   \def\blue#1{{\color{cyan}{#1}}}
   \def\green#1{{\color{yellowgreen}{#1}}}
\)
"""

import numpy as np
import torch
from torch import nn

from labml import tracker, experiment, logger, monit
from labml_helpers.module import Module
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

# run on GPU if available
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Model(Module):
    """
    ## <a name="model"></a>Neural Network Model for $Q$ Values

    #### Dueling Network ⚔️
    We are using a [dueling network](https://arxiv.org/abs/1511.06581)
     to calculate Q-values.
    The intuition behind the dueling network architecture is that in most states
     the action doesn't matter,
    while in some states the action is significant. The dueling network allows
     this to be represented very well.

    \begin{align}
        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
        \\
        \mathop{\mathbb{E}}_{a \sim \pi(s)}
         \Big[
          A^\pi(s, a)
         \Big]
        &= 0
    \end{align}

    So we create two networks for $V$ and $A$ and get $Q$ from them.
    $$
        Q(s, a) = V(s) +
        \Big(
            A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
        \Big)
    $$
    We share the initial layers of the $V$ and $A$ networks.
    """

    def __init__(self):
        """
        ### Initialize

        Two copies of this model are created: one for the training network
         and one for the target network.
        """

        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes a
            # 84x84 frame and produces a 20x20 frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # 20x20 frame and produces a 9x9 frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # 9x9 frame and produces a 7x7 frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # A fully connected layer takes the flattened
        # frame from the third convolution layer, and outputs
        # 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        # head that estimates the state value $V(s)$
        self.state_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # head that estimates the advantages $A(s, a)$ for the 4 actions
        self.action_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

        # activation for the fully connected layer
        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        h = self.conv(obs)
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        action_score = self.action_score(h)
        state_score = self.state_score(h)

        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
        q = state_score + action_score_centered

        return q
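
# A small sanity-check sketch. `_check_model_output` is an illustrative helper,
# not part of the module API: it just verifies that the dueling head produces
# one $Q$ value per action for each observation in a dummy batch.
def _check_model_output():
    model = Model()
    # a dummy batch of two observations, each a stack of 4 frames of 84x84 pixels
    q = model(torch.zeros(2, 4, 84, 84))
    assert q.shape == (2, 4)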


def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """Scale observations from `[0, 255]` to `[0, 1]`"""
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Main(object):
    """
    ## <a name="main"></a>Main class
    This class runs the training loop.
    It initializes the models, the replay buffer, and the workers;
     handles logging and monitoring;
     and runs the workers as multiple processes.
    """

    def __init__(self):
        """
        ### Initialize
        """

        # #### Configurations

        # number of workers
        self.n_workers = 8
        # steps sampled from each worker on each update
        self.worker_steps = 4
        # number of training epochs per update
        self.train_epochs = 8

        # number of updates
        self.updates = 1_000_000
        # size of mini batch for training
        self.mini_batch_size = 32

        # exploration $\epsilon$ as a function of the number of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

        # update the target network every 250 updates
        self.update_target_model = 250

        # $\beta$ for the replay buffer as a function of the number of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

        # replay buffer
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

        # training network and target network
        self.model = Model().to(device)
        self.target_model = Model().to(device)

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations; this keeps the last observation from each worker
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        """
        #### $\epsilon$-greedy Sampling
        When sampling actions we use an $\epsilon$-greedy strategy, where we
        take a greedy action with probability $1 - \epsilon$ and
        take a random action with probability $\epsilon$.
        We refer to $\epsilon$ as *exploration*.
        """

        with torch.no_grad():
            greedy_action = torch.argmax(q_value, dim=-1)
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        """### Sample data"""

        with torch.no_grad():
            # sample `worker_steps` steps from each worker
            for t in range(self.worker_steps):
                # sample actions
                q_value = self.model(obs_to_torch(self.obs))
                actions = self._sample_action(q_value, exploration_coefficient)

                # run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # collect information from each worker
                for w, worker in enumerate(self.workers):
                    # get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # add the transition to the replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # collect episode info, which is available if an episode finished;
                    #  this includes total reward and length of the episode -
                    #  look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

                    # update the current observation
                    self.obs[w] = next_obs

    def train(self, beta: float):
        """
        ### Train the model

        We want to find the optimal action-value function.

        \begin{align}
            Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
                r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
            \Big]
        \\
            Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
                r + \gamma \max_{a'} Q^* (s', a') | s, a
            \Big]
        \end{align}

        #### Target network 🎯
        In order to improve stability we use experience replay that randomly samples
        from previous experiences $U(D)$. We also use a Q network
        with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
        $\hl1{\theta_i^{-}}$ is updated periodically.
        This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).

        So the loss function is,
        $$
        \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
        \bigg[
            \Big(
                r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
            \Big) ^ 2
        \bigg]
        $$

        #### Double $Q$-Learning
        The max operator in the above calculation uses the same network for both
        selecting the best action and for evaluating the value.
        That is,
        $$
        \max_{a'} Q(s', a'; \theta) = \blue{Q}
        \Big(
            s', \mathop{\operatorname{argmax}}_{a'}
            \blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
        \Big)
        $$
        We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
        the $\operatorname{argmax}$ is taken from $\theta_i$ and
        the value is taken from $\theta_i^{-}$.

        And the loss function becomes,
        \begin{align}
            \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
            \Bigg[
                \bigg(
                    &r + \gamma \blue{Q}
                    \Big(
                        s',
                        \mathop{\operatorname{argmax}}_{a'}
                            \green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
                    \Big)
                    \\
                    - &Q(s,a;\theta_i)
                \bigg) ^ 2
            \Bigg]
        \end{align}
        """

        for _ in range(self.train_epochs):
            # sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # $Q(s, a; \theta_i)$ from the training network
            q_value = self.model(obs_to_torch(samples['obs']))

            with torch.no_grad():
                # $Q(s', a'; \theta_i)$, used to pick the best next action
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # $Q(s', a'; \theta_i^{-})$, used to evaluate it
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

            # $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # compute gradients
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            self.optimizer.step()

    def run_training_loop(self):
        """
        ### Run training loop
        """

        # copy to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        # keep stats of the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            # $\epsilon$, exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # sample with the current policy
            self.sample(exploration)

            if self.replay_buffer.is_full():
                # train the model
                self.train(beta)

                # periodically update the target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            tracker.save()
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        ### Destroy
        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))


# ## Run it
if __name__ == "__main__":
    experiment.create(name='dqn')
    m = Main()
    with experiment.start():
        m.run_training_loop()
    m.destroy()
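The action selection in `_sample_action` is plain $\epsilon$-greedy. A minimal standalone sketch with made-up Q values (the tensor values and `epsilon` below are illustrative only):

import torch

epsilon = 0.1
q_value = torch.tensor([[1.0, 2.0, 0.5, 0.1],
                        [0.3, 0.2, 0.9, 0.4]])

# greedy choice per sample -> actions [1, 2]
greedy_action = torch.argmax(q_value, dim=-1)
# a uniformly random action for each sample
random_action = torch.randint(q_value.shape[-1], greedy_action.shape)
# take the random action with probability epsilon, the greedy one otherwise
is_random = torch.rand(greedy_action.shape) < epsilon
actions = torch.where(is_random, random_action, greedy_action)

During training, $\epsilon$ itself is annealed by the `Piecewise` schedule defined in `Main.__init__`, from 1.0 at the start toward 0.01 (presumably by interpolating between the listed endpoints).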
							
								
								
									
236  labml_nn/rl/dqn/replay_buffer.py  Normal file

@ -0,0 +1,236 @@
import numpy as np
import random


class ReplayBuffer:
    """
    ## Buffer for Prioritized Experience Replay

    [Prioritized experience replay](https://arxiv.org/abs/1511.05952)
     samples important transitions more frequently.
    The transitions are prioritized by the Temporal Difference error.

    We sample transition $i$ with probability,
    $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
    where $\alpha$ is a hyper-parameter that determines how much
    prioritization is used, with $\alpha = 0$ corresponding to the uniform case.

    We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where
    $\delta_i$ is the temporal difference error for transition $i$.

    We correct the bias introduced by prioritized replay with
     importance-sampling (IS) weights
    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
    which fully compensate for the non-uniform probabilities when $\beta = 1$.
    We normalize the weights by $1/\max_i w_i$ for stability.
    Unbiased updates matter most near convergence at the end of training,
    so we increase $\beta$ towards the end of training.
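    For example, with $N = 4$ transitions, priorities $p = (1, 1, 2, 4)$ and
    $\alpha = 1$, the sampling probabilities are $P = (0.125, 0.125, 0.25, 0.5)$.
    With $\beta = 1$ the IS weights are $w_i = \frac{1}{N P(i)} = (2, 2, 1, 0.5)$,
    and after normalizing by $\max_i w_i = 2$ they become $(1, 1, 0.5, 0.25)$.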

    ### Binary Segment Trees
    We use binary segment trees to efficiently calculate
    $\sum_{k=1}^{i} p_k^\alpha$, the cumulative sum of priorities,
    which is needed to sample.
    We also use a binary segment tree to find $\min p_i^\alpha$,
    which is needed for $1/\max_i w_i$.
    We can also use a min-heap for this.

    This is how a binary segment tree works for the sum;
    it is similar for the minimum.
    Let $x_i$ be the list of $N$ values we want to represent.
    Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row
     in the binary tree.
    That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.

    The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
     will have values of $x$.
    Every node keeps the sum of its two child nodes.
    So the root node keeps the sum of the entire array of values.
    The two children of the root node keep
     the sum of the first half of the array and
     the sum of the second half of the array, and so on.

    $$b_{i,j} = \sum_{k = (j -1) * 2^{D - i} + 1}^{j * 2^{D - i}} x_k$$

    The number of nodes in row $i$ is
    $$N_i = \left\lceil \frac{N}{2^{D - i}} \right\rceil$$
    and the number of nodes in all rows above row $i$ is $N_i - 1$.
    So we can use a single array $a$ to store the tree, where,
    $$b_{i,j} = a_{N_i + j - 1}$$

    Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
    That is,
    $$a_i = a_{2i} + a_{2i + 1}$$

    This way of maintaining binary trees is very easy to program.
    *Note that we are indexing from 1*.
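    As a small example, take $N = 4$ values $x = (3, 1, 4, 2)$, so $D = 3$ and the
    array has $2N = 8$ slots with index $0$ unused. The leaves $a_4 \dots a_7$ hold
    $x$, the internal nodes hold $a_2 = a_4 + a_5 = 4$ and $a_3 = a_6 + a_7 = 6$,
    and the root $a_1 = a_2 + a_3 = 10$ is the total $\sum_k x_k$.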
    """

    def __init__(self, capacity, alpha):
        """
        ### Initialize
        """
        # we use a power of 2 for capacity because it simplifies the code and debugging
        self.capacity = capacity
        # we wrap around and overwrite the oldest samples once the queue reaches capacity
        self.next_idx = 0
        # $\alpha$
        self.alpha = alpha

        # maintain binary segment trees to take the sum and find the minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

        # current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.

        # arrays for the buffer
        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=bool)
        }

        # size of the buffer
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        """
        ### Add a sample to the queue
        """

        idx = self.next_idx

        # store in the queue
        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done

        # increment the head of the queue and calculate the size
        self.next_idx = (idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

        # $p_i^\alpha$, new samples get `max_priority`
        priority_alpha = self.max_priority ** self.alpha
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        """
        #### Set priority in binary segment tree for minimum
        """

        # leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha

        # update the tree by traversing along ancestors
        while idx >= 2:
            idx //= 2
            self.priority_min[idx] = min(self.priority_min[2 * idx],
                                         self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        """
        #### Set priority in binary segment tree for sum
        """

        # leaf of the binary tree
        idx += self.capacity
        self.priority_sum[idx] = priority

        # update the tree by traversing along ancestors
        while idx >= 2:
            idx //= 2
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
        """
        #### $\sum_k p_k^\alpha$
        """
        return self.priority_sum[1]

    def _min(self):
        """
        #### $\min_k p_k^\alpha$
        """
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        """
        #### Find the largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha  \le P$
        """

        # start from the root
        idx = 1
        while idx < self.capacity:
            # if the sum of the left branch is higher than the required sum
            if self.priority_sum[idx * 2] > prefix_sum:
                # go to the left branch of the tree
                idx = 2 * idx
            else:
                # otherwise go to the right branch and subtract the sum of the
                #  left branch from the required sum
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        return idx - self.capacity

    def sample(self, batch_size, beta):
        """
        ### Sample from the buffer
        """

        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }

        # get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx

        # $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
        prob_min = self._min() / self._sum()
        # $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
            # $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
            prob = self.priority_sum[idx + self.capacity] / self._sum()
            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
            weight = (prob * self.size) ** (-beta)
            # normalize by $\frac{1}{\max_i w_i}$,
            #  which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight

        # get the sample data
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        """
        ### Update priorities
        """
        for idx, priority in zip(indexes, priorities):
            self.max_priority = max(self.max_priority, priority)

            # $p_i^\alpha$
            priority_alpha = priority ** self.alpha
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        """
        ### Is the buffer full

        We only start training after the buffer is full.
        """
        return self.capacity == self.size
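A minimal usage sketch of `ReplayBuffer` with a tiny capacity; the dummy observations below just match the `(4, 84, 84)` frame-stack shape the buffer stores, and the priorities fed back are made up:

import numpy as np
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

# capacity should be a power of 2, matching the segment tree layout
buffer = ReplayBuffer(capacity=4, alpha=0.6)

# fill the buffer with dummy transitions
obs = np.zeros((4, 84, 84), dtype=np.uint8)
for action in range(4):
    buffer.add(obs, action, reward=1.0, next_obs=obs, done=False)
assert buffer.is_full()

# sample a mini-batch; `weights` are the importance-sampling weights
samples = buffer.sample(batch_size=2, beta=0.4)
print(samples['indexes'], samples['weights'])

# feed back new priorities, e.g. |TD error| + 1e-6
buffer.update_priorities(samples['indexes'], [0.5, 1.5])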
@ -128,7 +128,7 @@ class Trainer:
         # Value Loss
         self.value_loss = ClippedValueFunctionLoss()
 
-    def sample(self) -> (Dict[str, np.ndarray], List):
+    def sample(self) -> Dict[str, torch.Tensor]:
         """
         ### Sample data with current policy
         """