From ca258af351640262a00c7cb989e607c5c6050d4e Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Sun, 25 Oct 2020 09:44:43 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9A=20annotations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- labml_nn/rl/dqn/__init__.py | 2 +- labml_nn/rl/dqn/experiment.py | 2 +- labml_nn/rl/dqn/model.py | 41 +++++++------- labml_nn/rl/dqn/replay_buffer.py | 94 ++++++++++++++++++++------------ 4 files changed, 82 insertions(+), 57 deletions(-) diff --git a/labml_nn/rl/dqn/__init__.py b/labml_nn/rl/dqn/__init__.py index dd4956d5..d3ac73e1 100644 --- a/labml_nn/rl/dqn/__init__.py +++ b/labml_nn/rl/dqn/__init__.py @@ -1,5 +1,5 @@ """ -# Deep Q Networks +# Deep Q Networks (DQN) This is an implementation of paper [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) diff --git a/labml_nn/rl/dqn/experiment.py b/labml_nn/rl/dqn/experiment.py index 6a4cd36f..b08043a2 100644 --- a/labml_nn/rl/dqn/experiment.py +++ b/labml_nn/rl/dqn/experiment.py @@ -65,7 +65,7 @@ class Trainer: (self.updates, 1) ], outside_value=1) - # replay buffer with $\alpha = 0.6$ + # Replay buffer with $\alpha = 0.6$. Capacity of the replay buffer must be a power of 2. self.replay_buffer = ReplayBuffer(2 ** 14, 0.6) # Model for sampling and training diff --git a/labml_nn/rl/dqn/model.py b/labml_nn/rl/dqn/model.py index 0e40e349..94dcbd55 100644 --- a/labml_nn/rl/dqn/model.py +++ b/labml_nn/rl/dqn/model.py @@ -1,5 +1,5 @@ """ -# Neural Network Model +# Neural Network Model for Deep Q Network (DQN) """ import torch @@ -40,61 +40,60 @@ class Model(Module): """ def __init__(self): - """ - ### Initialize - - We need `scope` because we need multiple copies of variables - for target network and training network. 
- """ - super().__init__() self.conv = nn.Sequential( # The first convolution layer takes a - # 84x84 frame and produces a 20x20 frame + # $84\times84$ frame and produces a $20\times20$ frame nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4), nn.ReLU(), # The second convolution layer takes a - # 20x20 frame and produces a 9x9 frame + # $20\times20$ frame and produces a $9\times9$ frame nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), nn.ReLU(), # The third convolution layer takes a - # 9x9 frame and produces a 7x7 frame + # $9\times9$ frame and produces a $7\times7$ frame nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(), ) # A fully connected layer takes the flattened # frame from third convolution layer, and outputs - # 512 features + # $512$ features self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512) + self.activation = nn.ReLU() - self.state_score = nn.Sequential( + # This head gives the state value $V$ + self.state_value = nn.Sequential( nn.Linear(in_features=512, out_features=256), nn.ReLU(), nn.Linear(in_features=256, out_features=1), ) - self.action_score = nn.Sequential( + # This head gives the action value $A$ + self.action_value = nn.Sequential( nn.Linear(in_features=512, out_features=256), nn.ReLU(), nn.Linear(in_features=256, out_features=4), ) - # - self.activation = nn.ReLU() - def __call__(self, obs: torch.Tensor): + # Convolution h = self.conv(obs) + # Reshape for linear layers h = h.reshape((-1, 7 * 7 * 64)) + # Linear layer h = self.activation(self.lin(h)) - action_score = self.action_score(h) - state_score = self.state_score(h) + # $A$ + action_value = self.action_value(h) + # $V$ + state_value = self.state_value(h) + # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$ + action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True) # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$ - action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True) - q = state_score + action_score_centered + q = state_value + action_score_centered return q diff --git a/labml_nn/rl/dqn/replay_buffer.py b/labml_nn/rl/dqn/replay_buffer.py index 86dded7d..ed52857a 100644 --- a/labml_nn/rl/dqn/replay_buffer.py +++ b/labml_nn/rl/dqn/replay_buffer.py @@ -1,13 +1,14 @@ """ -# Prioritized Experience Replace Buffer +# Prioritized Experience Replay Buffer This implements paper [Prioritized experience replay](https://arxiv.org/abs/1511.05952), using a binary segment tree. """ -import numpy as np import random +import numpy as np + class ReplayBuffer: """ @@ -15,20 +16,21 @@ class ReplayBuffer: [Prioritized experience replay](https://arxiv.org/abs/1511.05952) samples important transitions more frequently. - The transitions are prioritized by the Temporal Difference error (td error). + The transitions are prioritized by the Temporal Difference error (td error), $\delta$. We sample transition $i$ with probability, $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$ where $\alpha$ is a hyper-parameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to uniform case. + $p_i$ is the priority. We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where $\delta_i$ is the temporal difference for transition $i$. 
-    We correct the bias introduced by prioritized replay by
+    We correct the bias introduced by prioritized replay using
     importance-sampling (IS) weights
-    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
-    that fully compensates for when $\beta = 1$.
+    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$ in the loss function.
+    This fully compensates when $\beta = 1$.
     We normalize weights by $\frac{1}{\max_i w_i}$ for stability.
     Unbiased nature is most important towards the convergence at end of training.
     Therefore we increase $\beta$ towards end of training.
@@ -40,6 +42,9 @@ class ReplayBuffer:
     We also use a binary segment tree to find $\min p_i^\alpha$,
     which is needed for $\frac{1}{\max_i w_i}$.
     We can also use a min-heap for this.
+    A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$
+    time, which is way more efficient than the naive $\mathcal{O}(n)$
+    approach.
 
     This is how a binary segment tree works for sum;
     it is similar for minimum.
@@ -51,15 +56,16 @@ class ReplayBuffer:
     The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
     will have values of $x$.
     Every node keeps the sum of the two child nodes.
-    So the root node keeps the sum of the entire array of values.
-    The two children of the root node keep
+    That is, the root node keeps the sum of the entire array of values.
+    The left and right children of the root node keep
     the sum of the first half of the array and
-    the sum of the second half of the array, and so on.
+    the sum of the second half of the array, respectively.
+    And so on...
 
     $$b_{i,j} = \sum_{k = (j -1) * 2^{D - i} + 1}^{j * 2^{D - i}} x_k$$
 
     Number of nodes in row $i$,
-    $$N_i = \left\lceil{\frac{N}{D - i + i}} \right\rceil$$
+    $$N_i = \left\lceil{\frac{N}{D - i + 1}} \right\rceil$$
     This is equal to the sum of nodes in all rows above $i$.
     So we can use a single array $a$ to store the tree, where,
     $$b_{i,j} \rightarrow a_{N_i + j}$$
@@ -71,28 +77,26 @@ class ReplayBuffer:
     This way of maintaining binary trees is very easy to program.
     *Note that we are indexing starting from 1*.
 
-    We using the same structure to compute the minimum.
+    We use the same structure to compute the minimum.
     """
 
     def __init__(self, capacity, alpha):
         """
         ### Initialize
         """
-        # we use a power of 2 for capacity to make it easy to debug
+        # We use a power of $2$ for capacity because it simplifies the code and debugging
         self.capacity = capacity
-        # we refill the queue once it reaches capacity
-        self.next_idx = 0
         # $\alpha$
         self.alpha = alpha
 
-        # maintain segment binary trees to take sum and find minimum over a range
+        # Maintain segment binary trees to take sum and find minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
 
-        # current max priority, $p$, to be assigned to new transitions
+        # Current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.
-        # arrays for buffer
+        # Arrays for buffer
         self.data = {
             'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
             'action': np.zeros(shape=capacity, dtype=np.int32),
@@ -100,8 +104,11 @@ class ReplayBuffer:
             'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
             'done': np.zeros(shape=capacity, dtype=np.bool)
         }
+        # We use cyclic buffers to store data, and `next_idx` keeps the index of the next empty
+        # slot
+        self.next_idx = 0
 
-        # size of the buffer
+        # Size of the buffer
         self.size = 0
 
     def add(self, obs, action, reward, next_obs, done):
@@ -109,6 +116,7 @@ class ReplayBuffer:
         ### Add sample to queue
         """
 
+        # Get next available slot
         idx = self.next_idx
 
         # store in the queue
@@ -118,12 +126,14 @@ class ReplayBuffer:
         self.data['next_obs'][idx] = next_obs
         self.data['done'][idx] = done
 
-        # increment head of the queue and calculate the size
+        # Increment next available slot
         self.next_idx = (idx + 1) % self.capacity
+        # Calculate the size
         self.size = min(self.capacity, self.size + 1)
 
         # $p_i^\alpha$, new samples get `max_priority`
         priority_alpha = self.max_priority ** self.alpha
+        # Update the two segment trees for sum and minimum
         self._set_priority_min(idx, priority_alpha)
         self._set_priority_sum(idx, priority_alpha)
 
@@ -132,40 +142,50 @@ class ReplayBuffer:
         #### Set priority in binary segment tree for minimum
         """
 
-        # leaf of the binary tree
+        # Leaf of the binary tree
         idx += self.capacity
         self.priority_min[idx] = priority_alpha
 
-        # update tree, by traversing along ancestors
+        # Update the tree by traversing along ancestors.
+        # Continue until the root of the tree.
         while idx >= 2:
+            # Get the index of the parent node
             idx //= 2
-            self.priority_min[idx] = min(self.priority_min[2 * idx],
-                                         self.priority_min[2 * idx + 1])
+            # Value of the parent node is the minimum of its two children
+            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])
 
     def _set_priority_sum(self, idx, priority):
         """
         #### Set priority in binary segment tree for sum
         """
 
-        # leaf of the binary tree
+        # Leaf of the binary tree
         idx += self.capacity
+        # Set the priority at the leaf
         self.priority_sum[idx] = priority
 
-        # update tree, by traversing along ancestors
+        # Update the tree by traversing along ancestors.
+        # Continue until the root of the tree.
         while idx >= 2:
+            # Get the index of the parent node
             idx //= 2
+            # Value of the parent node is the sum of its two children
             self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
 
     def _sum(self):
         """
         #### $\sum_k p_k^\alpha$
         """
+
+        # The root node keeps the sum of all values
         return self.priority_sum[1]
 
     def _min(self):
         """
         #### $\min_k p_k^\alpha$
         """
+
+        # The root node keeps the minimum of all values
         return self.priority_min[1]
 
     def find_prefix_sum_idx(self, prefix_sum):
@@ -173,19 +193,21 @@ class ReplayBuffer:
         #### Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha  \le P$
         """
 
-        # start from the root
+        # Start from the root
         idx = 1
         while idx < self.capacity:
-            # if the sum of the left branch is higher than required sum
+            # If the sum of the left branch is higher than required sum
             if self.priority_sum[idx * 2] > prefix_sum:
-                # go to left branch if the tree if the
+                # Go to left branch of the tree
                 idx = 2 * idx
             else:
-                # otherwise go to right branch and reduce the sum of left
+                # Otherwise go to right branch and reduce the sum of left
                 # branch from required sum
                 prefix_sum -= self.priority_sum[idx * 2]
                 idx = 2 * idx + 1
 
+        # We are at the leaf node. Subtract the capacity from the index in the tree
+        # to get the index of the actual value
         return idx - self.capacity
 
     def sample(self, batch_size, beta):
@@ -193,12 +215,13 @@ class ReplayBuffer:
         ### Sample from buffer
         """
 
+        # Initialize samples
         samples = {
             'weights': np.zeros(shape=batch_size, dtype=np.float32),
             'indexes': np.zeros(shape=batch_size, dtype=np.int32)
         }
 
-        # get samples
+        # Get sample indexes
         for i in range(batch_size):
             p = random.random() * self._sum()
             idx = self.find_prefix_sum_idx(p)
@@ -215,11 +238,11 @@ class ReplayBuffer:
             prob = self.priority_sum[idx + self.capacity] / self._sum()
             # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
             weight = (prob * self.size) ** (-beta)
-            # normalize by $\frac{1}{\max_i w_i}$,
+            # Normalize by $\frac{1}{\max_i w_i}$,
             # which also cancels off the $\frac{1}{N}$ term
             samples['weights'][i] = weight / max_weight
 
-        # get samples data
+        # Get samples data
         for k, v in self.data.items():
             samples[k] = v[samples['indexes']]
 
@@ -229,16 +252,19 @@ class ReplayBuffer:
         """
         ### Update priorities
         """
+
         for idx, priority in zip(indexes, priorities):
+            # Update the current max priority
             self.max_priority = max(self.max_priority, priority)
 
-            # $p_i^\alpha$
+            # Calculate $p_i^\alpha$
             priority_alpha = priority ** self.alpha
+            # Update the trees
             self._set_priority_min(idx, priority_alpha)
             self._set_priority_sum(idx, priority_alpha)
 
     def is_full(self):
         """
-        ### Is the buffer full
+        ### Whether the buffer is full
         """
         return self.capacity == self.size
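
The frame sizes quoted in the `model.py` comments above ($84\times84 \to 20\times20 \to 9\times9 \to 7\times7$) follow from the standard convolution output-size formula. A minimal sketch to verify them; the helper name `conv_out` is only for illustration and is not part of the codebase:

def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    # floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(84, kernel=8, stride=4)  # 84x84 -> 20x20
s = conv_out(s, kernel=4, stride=2)   # 20x20 -> 9x9
s = conv_out(s, kernel=3, stride=1)   # 9x9  -> 7x7
assert s == 7
# The flattened input to the linear layer is therefore 7 * 7 * 64 = 3136 features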
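
The dueling head in `model.py` combines the state value $V$ and the action values $A$ as $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$. A minimal PyTorch sketch of just that combination step, with random tensors standing in for the network outputs; the batch size of 32 is assumed for illustration, and the 4 actions match the model above:

import torch

state_value = torch.randn(32, 1)   # V(s): one scalar per state
action_value = torch.randn(32, 4)  # A(s, a): one value per action

# Subtract the mean action value so only relative advantages matter
action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
# Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a')); broadcasting expands V to all actions
q = state_value + action_score_centered
assert q.shape == (32, 4)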
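
The sampling probabilities $P(i)$ and importance-sampling weights $w_i$ from the `replay_buffer.py` docstring can be checked numerically. A small NumPy sketch; the TD errors and the $\epsilon$ value here are made up purely for illustration:

import numpy as np

td_errors = np.array([0.5, 2.0, 0.1, 1.0])
alpha, beta, eps = 0.6, 0.4, 0.1

# Proportional prioritization: p_i = |delta_i| + epsilon
p = np.abs(td_errors) + eps
# P(i) = p_i^alpha / sum_k p_k^alpha
probs = p ** alpha / np.sum(p ** alpha)
# w_i = (1/N * 1/P(i))^beta
n = len(td_errors)
weights = (n * probs) ** (-beta)
# Normalize by the largest weight, as the buffer does, for stability
weights = weights / weights.max()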
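
The flat-array layout shared by the two segment trees (`priority_sum` and `priority_min`) is easiest to see on a tiny example. The standalone sketch below mirrors the buffer's indexing, leaves at `capacity + i`, parent of node `idx` at `idx // 2`, root at index 1, and walks the same descent as `find_prefix_sum_idx`; the capacity of 4 and the priorities are arbitrary:

capacity = 4
priorities = [3.0, 1.0, 4.0, 2.0]

# Flat array of size 2 * capacity; index 0 is unused
tree = [0.0] * (2 * capacity)
for i, p in enumerate(priorities):
    tree[capacity + i] = p  # leaves occupy indexes capacity .. 2 * capacity - 1
for idx in range(capacity - 1, 0, -1):
    tree[idx] = tree[2 * idx] + tree[2 * idx + 1]  # each parent is the sum of its children

assert tree[1] == sum(priorities)  # the root holds the total sum

# Find the largest leaf i such that the prefix sum up to i does not exceed prefix_sum
prefix_sum, idx = 5.5, 1
while idx < capacity:
    if tree[2 * idx] > prefix_sum:
        # The target is inside the left subtree
        idx = 2 * idx
    else:
        # Skip the left subtree and subtract its sum
        prefix_sum -= tree[2 * idx]
        idx = 2 * idx + 1
assert idx - capacity == 2  # priorities[2] covers the cumulative range (4.0, 8.0]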
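
A rough end-to-end sketch of how the buffer is used during training, following the experiment configuration above (capacity $2^{14}$, a power of two, with $\alpha = 0.6$). The batch size, $\beta$, and the $\epsilon$ added to the TD errors are illustrative assumptions, and the real TD errors would come from the Q-network:

import numpy as np
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(2 ** 14, 0.6)

# Store one transition; observations are stacks of four 84x84 frames
obs = np.zeros((4, 84, 84), dtype=np.uint8)
next_obs = np.zeros((4, 84, 84), dtype=np.uint8)
buffer.add(obs, action=0, reward=1.0, next_obs=next_obs, done=False)

if buffer.is_full():
    # beta is annealed towards 1 as training progresses
    samples = buffer.sample(batch_size=32, beta=0.4)
    # ... compute TD errors for samples['indexes'] with the Q-network ...
    td_errors = np.random.rand(32)              # stand-in for the real TD errors
    new_priorities = np.abs(td_errors) + 1e-6   # p_i = |delta_i| + epsilon
    buffer.update_priorities(samples['indexes'], new_priorities)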