This implements the paper Prioritized Experience Replay, using a binary segment tree.
import random

import numpy as np

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the Temporal Difference error (TD error), $\delta$.
We sample transition $i$ with probability $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$ where $\alpha$ is a hyper-parameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. $p_i$ is the priority.
We use proportional prioritization $p_i = |\delta_i| + \epsilon$, where $\delta_i$ is the temporal difference for transition $i$.
We correct the bias introduced by prioritized replay using importance-sampling (IS) weights $$w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$$ in the loss function. This fully compensates for the bias when $\beta = 1$. We normalize the weights by $\frac{1}{\max_i w_i}$ for stability. The unbiased nature of the updates matters most near convergence at the end of training; therefore we increase $\beta$ towards $1$ as training progresses.
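As a concrete illustration of these formulas, here is a minimal NumPy sketch, separate from the buffer implementation below; the TD-error values, $\alpha$, and $\beta$ are arbitrary choices made up for the example.

import numpy as np

# Toy TD errors; the values, alpha and beta are illustrative assumptions
td_errors = np.array([0.5, 2.0, 0.1, 1.2])
alpha, beta, eps = 0.6, 0.4, 1e-6
N = len(td_errors)

# Proportional prioritization: p_i = |delta_i| + eps
priorities = np.abs(td_errors) + eps

# P(i) = p_i^alpha / sum_k p_k^alpha
probs = priorities ** alpha / np.sum(priorities ** alpha)

# w_i = (1/N * 1/P(i))^beta, normalized by max_i w_i for stability
weights = (N * probs) ** (-beta)
weights = weights / weights.max()

The larger the TD error, the higher the sampling probability and the smaller the normalized IS weight, which is the trade-off the buffer below implements with segment trees.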
We use a binary segment tree to efficiently calculate $\sum_{k=1}^{i} p_k^\alpha$, the cumulative probability, which is needed to sample. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. We can also use a min-heap for this. A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$ time, which is much more efficient than the naive $\mathcal{O}(n)$ approach.
This is how a binary segment tree works for sum; it is similar for minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j^{\text{th}}$ node of the $i^{\text{th}}$ row in the binary tree. That is, the two children of node $b_{i,j}$ are $b_{i+1,2j-1}$ and $b_{i+1,2j}$.

The leaf nodes on row $D = \left\lceil 1 + \log_2 N \right\rceil$ will have the values of $x$. Every node keeps the sum of its two child nodes. That is, the root node keeps the sum of the entire array of values. The left and right children of the root node keep the sum of the first half of the array and the sum of the second half of the array, respectively. And so on...

$$b_{i,j} = \sum_{k = (j - 1) \cdot 2^{D - i} + 1}^{j \cdot 2^{D - i}} x_k$$
The number of nodes in row $i$ is $2^{i-1}$, and the rows above it hold $2^{i-1} - 1$ nodes in total. So we can use a single array $a$ to store the tree, placing each row right after the rows above it:

$$b_{i,j} \rightarrow a_{2^{i-1} + j - 1}$$

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i+1}$. That is,

$$a_i = a_{2i} + a_{2i+1}$$

This way of maintaining binary trees is very easy to program. Note that we are indexing starting from 1.
We use the same structure to compute the minimum.
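To make the layout above concrete, here is a minimal sketch with made-up values; it is separate from the ReplayBuffer implementation that follows, which stores $p_i^\alpha$ in exactly this kind of flat array.

# A small flat-array sum tree, indexed from 1, for four made-up values
values = [3.0, 1.0, 4.0, 2.0]      # x_1 .. x_N, with N a power of 2
N = len(values)
a = [0.0] * (2 * N)                # a[1] is the root; a[0] is unused

# The leaves of row D occupy a[N] .. a[2N - 1]
for j, x in enumerate(values):
    a[N + j] = x

# Every internal node is the sum of its two children: a[i] = a[2i] + a[2i + 1]
for i in range(N - 1, 0, -1):
    a[i] = a[2 * i] + a[2 * i + 1]

assert a[1] == sum(values)         # the root holds the sum of all values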
class ReplayBuffer:
    def __init__(self, capacity, alpha):
We use a power of $2$ for capacity because it simplifies the code and debugging
        self.capacity = capacity
        self.alpha = alpha
Maintain binary segment trees to take the sum and find the minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
Current maximum priority, to be assigned to new transitions
        self.max_priority = 1.
Arrays for buffer
        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=bool)
        }
We use cyclic buffers to store data, and next_idx keeps the index of the next empty slot
        self.next_idx = 0
Size of the buffer
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
Get next available slot
        idx = self.next_idx
Store in the queue
        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done
Increment next available slot
        self.next_idx = (idx + 1) % self.capacity
Calculate the size
        self.size = min(self.capacity, self.size + 1)
$p_i^\alpha$, new samples get max_priority
        priority_alpha = self.max_priority ** self.alpha
Update the two segment trees for sum and minimum
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
Leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha
Update the tree by traversing along the ancestors. Continue until the root of the tree.
        while idx >= 2:
Get the index of the parent node
            idx //= 2
The value of the parent node is the minimum of its two children
            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
Leaf of the binary tree
        idx += self.capacity
Set the priority at the leaf
        self.priority_sum[idx] = priority
Update the tree by traversing along the ancestors. Continue until the root of the tree.
        while idx >= 2:
Get the index of the parent node
            idx //= 2
The value of the parent node is the sum of its two children
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
The root node keeps the sum of all values
        return self.priority_sum[1]

    def _min(self):
The root node keeps the minimum of all values
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
Start from the root
        idx = 1
        while idx < self.capacity:
If the sum of the left branch is higher than the required sum
            if self.priority_sum[idx * 2] > prefix_sum:
Go to the left branch of the tree
                idx = 2 * idx
            else:
Otherwise, go to the right branch and subtract the sum of the left branch from the required sum
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1
We are at the leaf node. Subtract the capacity from the tree index to get the index of the actual value
        return idx - self.capacity

    def sample(self, batch_size, beta):
Initialize samples
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }
Get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx
        prob_min = self._min() / self._sum()
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
            prob = self.priority_sum[idx + self.capacity] / self._sum()
            weight = (prob * self.size) ** (-beta)
Normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight
Get the sample data
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):
Set current max priority
            self.max_priority = max(self.max_priority, priority)
Calculate $p_i^\alpha$
            priority_alpha = priority ** self.alpha
Update the trees
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        return self.capacity == self.size
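Here is a sketch of how this buffer might be driven from a DQN-style training loop. The environment data, the TD-error computation, and the linear $\beta$ schedule are placeholders invented for the example; only the ReplayBuffer calls (add, sample, update_priorities) come from the class above.

# A minimal usage sketch; observations and TD errors are dummy placeholders
import numpy as np

buffer = ReplayBuffer(capacity=2 ** 10, alpha=0.6)
total_steps = 10_000

for step in range(total_steps):
    # In a real agent these would come from the environment and the policy
    obs = np.zeros((4, 84, 84), dtype=np.uint8)
    next_obs = np.zeros((4, 84, 84), dtype=np.uint8)
    buffer.add(obs, action=0, reward=0.0, next_obs=next_obs, done=False)

    if buffer.size < 512:
        continue

    # Anneal beta towards 1 so the IS correction is full near the end of training
    beta = min(1.0, 0.4 + 0.6 * step / total_steps)
    batch = buffer.sample(batch_size=32, beta=beta)
    # batch['weights'] would scale the per-sample loss in the learner

    # Placeholder TD errors; a real learner would compute |delta_i| from its targets
    td_errors = np.random.rand(32)
    buffer.update_priorities(batch['indexes'], np.abs(td_errors) + 1e-6)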