This implements the paper Prioritized Experience Replay, using a binary segment tree.
import random

import numpy as np

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the temporal difference error (TD error), $\delta$.
We sample transition $i$ with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, where $\alpha$ is a hyperparameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. $p_i$ is the priority.
We use proportional prioritization $p_i = |\delta_i| + \epsilon$, where $\delta_i$ is the temporal difference error for transition $i$ and $\epsilon$ is a small positive constant that keeps every priority non-zero.
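The two formulas above can be checked numerically. The following is a minimal sketch, not the buffer's own code path: the function name `sampling_probabilities`, the value of `epsilon`, and the TD errors are all assumptions made for illustration.

```python
import numpy as np


def sampling_probabilities(td_errors, alpha, epsilon=1e-6):
    """Return P(i) = p_i^alpha / sum_k p_k^alpha with p_i = |delta_i| + epsilon."""
    # epsilon is an assumed small constant; it keeps zero-error transitions sampleable
    priorities = np.abs(np.asarray(td_errors, dtype=np.float64)) + epsilon
    scaled = priorities ** alpha
    return scaled / scaled.sum()


td_errors = [0.5, -2.0, 0.0, 1.0]  # hypothetical TD errors
uniform = sampling_probabilities(td_errors, alpha=0.0)      # every entry is 0.25
prioritized = sampling_probabilities(td_errors, alpha=0.6)  # largest |delta| is most likely
```

With $\alpha = 0$ every $p_i^\alpha$ equals $1$, so the distribution is uniform regardless of the TD errors. With $\alpha > 0$ the transition with the largest $|\delta_i|$ gets the highest probability.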
We correct the bias introduced by prioritized replay using importance-sampling (IS) weights $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$ in the loss function. This fully compensates for the bias when $\beta = 1$. We normalize the weights by $\frac{1}{\max_i w_i}$ for stability. Being unbiased matters most near convergence at the end of training, so we increase $\beta$ towards the end of training.
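As a rough illustration of the weighting, the sketch below computes normalized IS weights from a given probability vector. The helper names `is_weights` and `linear_beta`, the linear schedule, and the starting value `beta_start=0.4` are assumptions for this example, not something this implementation prescribes.

```python
import numpy as np


def is_weights(probs, sampled, beta):
    """w_i = (1/N * 1/P(i))^beta for the sampled indexes, divided by max_i w_i."""
    probs = np.asarray(probs, dtype=np.float64)
    n = len(probs)
    weights = (n * probs[sampled]) ** (-beta)
    # The largest weight over the whole buffer belongs to the smallest probability
    max_weight = (n * probs.min()) ** (-beta)
    return weights / max_weight


def linear_beta(step, total_steps, beta_start=0.4):
    """Hypothetical schedule: anneal beta linearly from beta_start to 1."""
    fraction = min(1.0, step / total_steps)
    return beta_start + fraction * (1.0 - beta_start)


probs = [0.1, 0.2, 0.3, 0.4]  # hypothetical sampling probabilities P(i)
full = is_weights(probs, sampled=[0, 3], beta=1.0)  # full correction
none = is_weights(probs, sampled=[0, 3], beta=0.0)  # no correction, all weights are 1
```

After normalization the rarest transition has weight $1$ and every other weight is at most $1$, so the weights only ever scale updates down.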
We use a binary segment tree to efficiently calculate $\sum_{k \le i} p_k^\alpha$, the cumulative sum of priorities, which is needed to sample. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. A min-heap would also work for the minimum. A binary segment tree lets us update and query these in $\mathcal{O}(\log n)$ time, compared with $\mathcal{O}(n)$ for the naive approach.
This is how a binary segment tree works for sum; it is similar for minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row in the binary tree. That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.
The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$ will have values of $x$. Every node keeps the sum of the two child nodes. That is, the root node keeps the sum of the entire array of values. The left and right children of the root node keep the sum of the first half of the array and the sum of the second half of the array, respectively. And so on…
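The number of nodes in row $i$ is $N_i = 2^{i-1}$. This is one more than the total number of nodes in all rows above $i$, which is $2^{i-1} - 1$. So we can use a single array $a$ to store the tree, where, counting $j$ from $0$ within each row, $b_{i,j} \rightarrow a_{N_i + j}$.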
Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$. That is, $a_i = a_{2i} + a_{2i + 1}$.
This way of maintaining binary trees is very easy to program. Note that we are indexing starting from 1.
We use the same structure to compute the minimum.
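The array layout described above can be sketched independently of the replay buffer. In this minimal example the helper names `build_tree` and `update` and the sample values are assumptions, and the length is assumed to be a power of $2$. The same code serves both sum and minimum by swapping the combining function.

```python
import operator


def build_tree(values, combine, identity):
    """Store a binary segment tree in a 1-indexed array of length 2 * n."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-2 length"
    tree = [identity] * (2 * n)  # a[0] is unused
    tree[n:] = values            # leaves occupy a[n] .. a[2n - 1]
    for i in range(n - 1, 0, -1):
        # Each internal node combines its children a[2i] and a[2i + 1]
        tree[i] = combine(tree[2 * i], tree[2 * i + 1])
    return tree


def update(tree, idx, value, combine):
    """Set leaf idx, then recompute its ancestors up to the root a[1]."""
    i = idx + len(tree) // 2
    tree[i] = value
    while i >= 2:
        i //= 2
        tree[i] = combine(tree[2 * i], tree[2 * i + 1])


values = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
sum_tree = build_tree(values, operator.add, 0.0)
min_tree = build_tree(values, min, float('inf'))

# Change the second value from 1.0 to 7.0; only log2(8) = 3 ancestors are touched
update(sum_tree, 1, 7.0, operator.add)
update(min_tree, 1, 7.0, min)
```

After the update, `sum_tree[1]` holds the total of all values, `sum_tree[2]` and `sum_tree[3]` hold the sums of the first and second halves, and `min_tree[1]` holds the overall minimum.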
class ReplayBuffer:
    def __init__(self, capacity, alpha):
        # We use a power of $2$ for capacity because it simplifies the code and debugging
        self.capacity = capacity
        # $\alpha$, the prioritization exponent
        self.alpha = alpha

        # Maintain binary segment trees to take the sum and find the minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

        # Current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.

        # Arrays for buffer
        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=bool)
        }

        # We use cyclic buffers to store data, and next_idx keeps the index of the next empty slot
        self.next_idx = 0

        # Size of the buffer
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        # Get next available slot
        idx = self.next_idx

        # Store in the queue
        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done

        # Increment next available slot
        self.next_idx = (idx + 1) % self.capacity
        # Calculate the size
        self.size = min(self.capacity, self.size + 1)

        # $p_i^\alpha$, new samples get max_priority
        priority_alpha = self.max_priority ** self.alpha
        # Update the two segment trees for sum and minimum
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        # Leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha

        # Update tree, by traversing along ancestors. Continue until the root of the tree.
        while idx >= 2:
            # Get the index of the parent node
            idx //= 2
            # Value of the parent node is the minimum of its two children
            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        # Leaf of the binary tree
        idx += self.capacity
        # Set the priority at the leaf
        self.priority_sum[idx] = priority

        # Update tree, by traversing along ancestors. Continue until the root of the tree.
        while idx >= 2:
            # Get the index of the parent node
            idx //= 2
            # Value of the parent node is the sum of its two children
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
        # The root node keeps the sum of all values
        return self.priority_sum[1]

    def _min(self):
        # The root node keeps the minimum of all values
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        # Start from the root
        idx = 1
        while idx < self.capacity:
            # If the sum of the left branch is higher than the required sum
            if self.priority_sum[idx * 2] > prefix_sum:
                # Go to the left branch of the tree
                idx = 2 * idx
            else:
                # Otherwise go to the right branch and subtract the sum of the left branch from the required sum
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        # We are at the leaf node. Subtract the capacity from the index in the tree to get the index of the actual value
        return idx - self.capacity

    def sample(self, batch_size, beta):
        # Initialize samples
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }

        # Get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx

        # $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
        prob_min = self._min() / self._sum()
        # $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
            # $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
            prob = self.priority_sum[idx + self.capacity] / self._sum()
            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
            weight = (prob * self.size) ** (-beta)
            # Normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight

        # Get samples data
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):
            # Set current max priority
            self.max_priority = max(self.max_priority, priority)

            # Calculate $p_i^\alpha$
            priority_alpha = priority ** self.alpha
            # Update the trees
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        return self.capacity == self.size
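
To see why the prefix-sum search in `find_prefix_sum_idx` samples in proportion to priority, it helps to trace it on a tiny tree. The sketch below is a standalone version of the same descent, written as a free function over a plain list. The priorities `[1, 2, 3, 4]` and the probe values are hypothetical.

```python
def find_prefix_sum_idx(tree, prefix_sum):
    """Return leaf i with sum(values[:i]) <= prefix_sum < sum(values[:i + 1])."""
    capacity = len(tree) // 2
    idx = 1  # start from the root
    while idx < capacity:
        left = tree[2 * idx]
        if left > prefix_sum:
            # The target falls inside the left subtree
            idx = 2 * idx
        else:
            # Skip the left subtree and look for the remainder on the right
            prefix_sum -= left
            idx = 2 * idx + 1
    return idx - capacity


# Sum tree for the hypothetical priorities [1, 2, 3, 4], stored 1-indexed (a[0] unused):
# root = 10, row 2 = [3, 7], leaves = [1, 2, 3, 4]
tree = [0, 10, 3, 7, 1, 2, 3, 4]
picks = [find_prefix_sum_idx(tree, p) for p in (0.5, 1.5, 3.5, 9.9)]
```

The leaves own the intervals $[0, 1)$, $[1, 3)$, $[3, 6)$ and $[6, 10)$, so a value drawn uniformly from $[0, 10)$ lands on leaf $i$ with probability proportional to its priority. The four probes above resolve to leaves 0, 1, 2 and 3 respectively.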