<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Annotated implementation of prioritized experience replay using a binary segment tree."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Prioritized Experience Replay Buffer"/>
<meta name="twitter:description" content="Annotated implementation of prioritized experience replay using a binary segment tree."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/dqn/replay_buffer.html"/>
<meta property="og:title" content="Prioritized Experience Replay Buffer"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="Annotated implementation of prioritized experience replay using a binary segment tree."/>
<title>Prioritized Experience Replay Buffer</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/rl/dqn/replay_buffer.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/rl/dqn/replay_buffer.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Prioritized Experience Replay Buffer</h1>
<p>This implements the paper <a href="https://arxiv.org/abs/1511.05952">Prioritized Experience Replay</a>,
using a binary segment tree.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">random</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Buffer for Prioritized Experience Replay</h2>
<p><a href="https://arxiv.org/abs/1511.05952">Prioritized experience replay</a>
samples important transitions more frequently.
The transitions are prioritized by the temporal difference error (TD error), $\delta$.</p>
<p>We sample transition $i$ with probability,
<script type="math/tex; mode=display">P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}</script>
where $\alpha$ is a hyper-parameter that determines how much
prioritization is used, with $\alpha = 0$ corresponding to the uniform case.
$p_i$ is the priority of transition $i$.</p>
<p>We use proportional prioritization $p_i = |\delta_i| + \epsilon$, where
$\delta_i$ is the temporal difference error for transition $i$ and
$\epsilon$ is a small positive constant that keeps every priority non-zero.</p>
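<p>As a rough illustration of the two formulas above, the sketch below computes $P(i)$ directly with NumPy,
without any tree. The TD error values and the function name <code>sampling_probabilities</code> are hypothetical
and are not part of the implementation on this page.</p>
<pre>
```python
import numpy as np


def sampling_probabilities(td_errors, alpha, epsilon=1e-6):
    """Return P(i) = p_i^alpha / sum_k p_k^alpha with p_i = |delta_i| + epsilon."""
    priorities = np.abs(td_errors) + epsilon
    scaled = priorities ** alpha
    return scaled / scaled.sum()


# Hypothetical TD errors for five stored transitions (illustrative values only).
td_errors = np.array([0.5, -2.0, 0.1, 0.0, 1.0])

# alpha = 0 ignores the priorities, so every transition is equally likely.
uniform = sampling_probabilities(td_errors, alpha=0.0)
# alpha = 1 samples in direct proportion to the priorities.
proportional = sampling_probabilities(td_errors, alpha=1.0)
```
</pre>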
<p>We correct the bias introduced by prioritized replay using
importance-sampling (IS) weights
<script type="math/tex; mode=display">w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta</script> in the loss function,
where $N$ is the number of stored transitions.
This fully compensates for the non-uniform sampling when $\beta = 1$.
We normalize the weights by $\frac{1}{\max_i w_i}$ for stability.
Unbiased updates matter most near convergence at the end of training,
so we increase $\beta$ towards $1$ as training progresses.</p>
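<p>The weight formula can be sketched as follows. This is a minimal example under assumed values:
the probabilities and the function name <code>importance_weights</code> are hypothetical.
Note that the largest weight always belongs to the transition with the smallest $P(i)$,
which is why only the minimum priority is needed for the normalization.</p>
<pre>
```python
import numpy as np


def importance_weights(probs, beta):
    """w_i = (1 / (N * P(i)))^beta, scaled so that the largest weight is 1."""
    n = len(probs)
    weights = (1.0 / (n * probs)) ** beta
    return weights / weights.max()


# Hypothetical sampling probabilities P(i) for four transitions.
probs = np.array([0.1, 0.2, 0.3, 0.4])

# beta = 0 applies no correction: every weight is 1.
no_correction = importance_weights(probs, beta=0.0)
# beta = 1 applies the full correction: weights are proportional to 1 / P(i).
full_correction = importance_weights(probs, beta=1.0)
```
</pre>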
<h3>Binary Segment Tree</h3>
<p>We use a binary segment tree to efficiently calculate
$\sum_{k=1}^{i} p_k^\alpha$, the cumulative sum of priorities,
which is needed to sample.
We also use a binary segment tree to find $\min_i p_i^\alpha$,
which is needed for $\frac{1}{\max_i w_i}$.
A min-heap could also be used to find the minimum.
A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$
time, which is much more efficient than the naive $\mathcal{O}(n)$
approach.</p>
<p>This is how a binary segment tree works for sum;
it is similar for minimum.
Let $x_i$ be the list of $N$ values we want to represent.
Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row
in the binary tree, with rows counted from $1$ at the root and nodes within a row counted from $0$.
That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.</p>
<p>The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
will have values of $x$.
Every node keeps the sum of the two child nodes.
That is, the root node keeps the sum of the entire array of values.
The left and right children of the root node keep
the sum of the first half of the array and
the sum of the second half of the array, respectively.
And so on&hellip;</p>
<p>
<script type="math/tex; mode=display">b_{i,j} = \sum_{k = j * 2^{D - i} + 1}^{(j + 1) * 2^{D - i}} x_k</script>
</p>
<p>Number of nodes in row $i$,
<script type="math/tex; mode=display">N_i = 2^{i - 1}</script>
This is one more than the total number of nodes in all rows above $i$.
So we can use a single array $a$ to store the tree, where,
<script type="math/tex; mode=display">b_{i,j} \rightarrow a_{N_i + j}</script>
</p>
<p>Then child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
That is,
<script type="math/tex; mode=display">a_i = a_{2i} + a_{2i + 1}</script>
</p>
<p>This way of maintaining binary trees is very easy to program.
<em>Note that we are indexing starting from 1</em>.</p>
<p>We use the same structure to compute the minimum.</p>
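<p>The array layout described above can be sketched as a small stand-alone function.
This is only an illustration, assuming the number of values is a power of two;
<code>build_sum_tree</code> is a hypothetical helper and the values are made up.</p>
<pre>
```python
def build_sum_tree(values):
    """Store a sum tree for `values` in a flat list using 1-based indexing.

    Assumes len(values) is a power of two. Index 0 is unused, leaves live at
    indexes n .. 2n - 1, and node i has children 2i and 2i + 1.
    """
    n = len(values)
    tree = [0] * (2 * n)
    tree[n:] = values
    # Fill the internal nodes from the bottom row upwards.
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


tree = build_sum_tree([3, 1, 4, 1, 5, 9, 2, 6])
# tree[1] holds the total, tree[2] the sum of the first half,
# and tree[3] the sum of the second half.
```
</pre>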
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">ReplayBuffer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>Initialize</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">capacity</span><span class="p">,</span> <span class="n">alpha</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>We use a power of $2$ for capacity because it simplifies the code and debugging.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span> <span class="o">=</span> <span class="n">capacity</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>$\alpha$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Maintain binary segment trees to take the sum and find the minimum over a range</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span><span class="p">)]</span>
<span class="lineno">99</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_min</span> <span class="o">=</span> <span class="p">[</span><span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Current max priority, $p$, to be assigned to new transitions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_priority</span> <span class="o">=</span> <span class="mf">1.</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Arrays for buffer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">106</span> <span class="s1">&#39;obs&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">capacity</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span>
<span class="lineno">107</span> <span class="s1">&#39;action&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">capacity</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="lineno">108</span> <span class="s1">&#39;reward&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">capacity</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="lineno">109</span> <span class="s1">&#39;next_obs&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">capacity</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span>
<span class="lineno">110</span> <span class="s1">&#39;done&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">capacity</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">bool_</span><span class="p">)</span>
<span class="lineno">111</span> <span class="p">}</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>We use cyclic buffers to store data, and <code>next_idx</code> keeps the index of the next slot
to write; once the buffer is full this is the oldest entry, which gets overwritten</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_idx</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Size of the buffer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>Add sample to queue</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obs</span><span class="p">,</span> <span class="n">action</span><span class="p">,</span> <span class="n">reward</span><span class="p">,</span> <span class="n">next_obs</span><span class="p">,</span> <span class="n">done</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Get next available slot</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_idx</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Store in the queue</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;obs&#39;</span><span class="p">][</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">obs</span>
<span class="lineno">129</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">][</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">action</span>
<span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">][</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">reward</span>
<span class="lineno">131</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;next_obs&#39;</span><span class="p">][</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">next_obs</span>
<span class="lineno">132</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;done&#39;</span><span class="p">][</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">done</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Increment next available slot</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Calculate the size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>$p_i^\alpha$, new samples get <code>max_priority</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">priority_alpha</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_priority</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Update the two segment trees for sum and minimum</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set_priority_min</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">priority_alpha</span><span class="p">)</span>
<span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set_priority_sum</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">priority_alpha</span><span class="p">)</span></pre></div>
</div>
</div>
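<div class='section'>
<div class='docs'>
<p>The index bookkeeping in <code>add</code> can be traced in isolation.
The sketch below is illustrative only: it uses a hypothetical capacity of four and tracks
which slot each of six additions would write to, without storing any data.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
capacity = 4  # hypothetical small capacity for illustration
next_idx, size = 0, 0
written_slots = []

for _ in range(6):  # add six transitions to a buffer with four slots
    written_slots.append(next_idx)
    # Wrap around to slot 0 after the last slot.
    next_idx = (next_idx + 1) % capacity
    # The size grows until the buffer is full, then stays at capacity.
    size = min(capacity, size + 1)
```
</pre></div>
</div>
</div>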
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<h4>Set priority in binary segment tree for minimum</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">def</span> <span class="nf">_set_priority_min</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">priority_alpha</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Leaf of the binary tree</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">idx</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span>
<span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_min</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">priority_alpha</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Update the tree by traversing the ancestors of the leaf,
continuing until we reach the root.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">while</span> <span class="n">idx</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Get the index of the parent node</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">idx</span> <span class="o">//=</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Value of the parent node is the minimum of its two children</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_min</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">priority_min</span><span class="p">[</span><span class="mi">2</span> <span class="o">*</span> <span class="n">idx</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_min</span><span class="p">[</span><span class="mi">2</span> <span class="o">*</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
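<div class='section'>
<div class='docs'>
<p>A stand-alone sketch of the same update, written as a plain function so it can be tried outside the class.
The name <code>set_min</code>, the capacity and the priorities are assumptions made for illustration.
Because unused leaves start at infinity, empty slots never affect the minimum at the root.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def set_min(tree, capacity, idx, value):
    """Write `value` at leaf `idx` and refresh the minimum on the path to the root."""
    idx += capacity
    tree[idx] = value
    while idx >= 2:
        idx //= 2
        tree[idx] = min(tree[2 * idx], tree[2 * idx + 1])


capacity = 4  # assumed to be a power of two
min_tree = [float('inf')] * (2 * capacity)

# Fill three of the four slots; the last leaf stays at infinity.
for slot, priority in enumerate([0.9, 0.4, 0.7]):
    set_min(min_tree, capacity, slot, priority)

smallest = min_tree[1]  # the root holds the minimum over all filled slots
```
</pre></div>
</div>
</div>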
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<h4>Set priority in binary segment tree for sum</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="k">def</span> <span class="nf">_set_priority_sum</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">priority</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Leaf of the binary tree</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">idx</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Set the priority at the leaf</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">priority</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Update the tree by traversing the ancestors of the leaf,
continuing until we reach the root.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">while</span> <span class="n">idx</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Get the index of the parent node</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">idx</span> <span class="o">//=</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Value of the parent node is the sum of its two children</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="mi">2</span> <span class="o">*</span> <span class="n">idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="mi">2</span> <span class="o">*</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
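<div class='section'>
<div class='docs'>
<p>To make the $\mathcal{O}(\log n)$ cost concrete, the sketch below repeats the sum update as a plain function
that also counts how many ancestors it rewrites. The function name <code>set_sum</code>, the capacity and the
chosen slot are hypothetical. With a capacity of $1024 = 2^{10}$, one update touches exactly $10$ internal nodes.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def set_sum(tree, capacity, idx, value):
    """Write `value` at leaf `idx`, update its ancestors, and return how many were updated."""
    idx += capacity
    tree[idx] = value
    updated = 0
    while idx >= 2:
        idx //= 2
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]
        updated += 1
    return updated


capacity = 1024  # assumed to be a power of two
sum_tree = [0.0] * (2 * capacity)
touched = set_sum(sum_tree, capacity, 700, 2.5)
```
</pre></div>
</div>
</div>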
<div class='section' id='section-28'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<h4>$\sum_k p_k^\alpha$</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="k">def</span> <span class="nf">_sum</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>The root node keeps the sum of all values</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<h4>$\min_k p_k^\alpha$</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="k">def</span> <span class="nf">_min</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>The root node keeps the minimum of all values</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_min</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<h4>Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">def</span> <span class="nf">find_prefix_sum_idx</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prefix_sum</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Start from the root</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">idx</span> <span class="o">=</span> <span class="mi">1</span>
<span class="lineno">203</span> <span class="k">while</span> <span class="n">idx</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>If the sum of the left branch is higher than the required sum</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="n">idx</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">prefix_sum</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Go to the left branch of the tree</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">idx</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">idx</span>
<span class="lineno">208</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Otherwise go to the right branch and subtract the sum of the left
branch from the required sum</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="n">prefix_sum</span> <span class="o">-=</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="n">idx</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span>
<span class="lineno">212</span> <span class="n">idx</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>We are at the leaf node. Subtract the capacity from the index in the tree
to get the index of the actual value</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="k">return</span> <span class="n">idx</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span></pre></div>
</div>
</div>
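<div class='section'>
<div class='docs'>
<p>The descent can be checked by hand on a tiny tree. The following is a sketch with hypothetical priorities
$1, 2, 3, 4$, whose cumulative sums are $1, 3, 6, 10$; the function is a stand-alone copy of the search above
that takes the tree and capacity as arguments. A prefix sum that falls in $[0, 1)$ maps to slot $0$,
$[1, 3)$ to slot $1$, $[3, 6)$ to slot $2$ and $[6, 10)$ to slot $3$.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def find_prefix_sum_idx(tree, capacity, prefix_sum):
    """Descend from the root to the leaf whose cumulative range contains `prefix_sum`."""
    idx = 1
    while idx < capacity:
        left = tree[2 * idx]
        if left > prefix_sum:
            # The target lies inside the left subtree.
            idx = 2 * idx
        else:
            # Skip the left subtree and look for the remainder on the right.
            prefix_sum -= left
            idx = 2 * idx + 1
    return idx - capacity


# Hypothetical priorities; the list length is assumed to be a power of two.
priorities = [1.0, 2.0, 3.0, 4.0]
capacity = len(priorities)
tree = [0.0] * capacity + priorities
for i in range(capacity - 1, 0, -1):
    tree[i] = tree[2 * i] + tree[2 * i + 1]

found = [find_prefix_sum_idx(tree, capacity, p) for p in (0.0, 0.5, 1.0, 2.9, 3.0, 9.9)]
```
</pre></div>
</div>
</div>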
<div class='section' id='section-38'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<h3>Sample from buffer</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beta</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Initialize samples</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">samples</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">225</span> <span class="s1">&#39;weights&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="lineno">226</span> <span class="s1">&#39;indexes&#39;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="lineno">227</span> <span class="p">}</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Get sample indexes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="lineno">231</span> <span class="n">p</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="p">()</span>
<span class="lineno">232</span> <span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">find_prefix_sum_idx</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="lineno">233</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;indexes&#39;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">idx</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">prob_min</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_min</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>$\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">238</span> <span class="n">max_weight</span> <span class="o">=</span> <span class="p">(</span><span class="n">prob_min</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">)</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="n">beta</span><span class="p">)</span>
<span class="lineno">239</span>
<span class="lineno">240</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="lineno">241</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;indexes&#39;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="n">prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">priority_sum</span><span class="p">[</span><span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span><span class="p">]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">245</span> <span class="n">weight</span> <span class="o">=</span> <span class="p">(</span><span class="n">prob</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">)</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="n">beta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Normalize by $\frac{1}{\max_i w_i}$,
which also cancels out the $\frac{1}{N}$ term</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">248</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;weights&#39;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">/</span> <span class="n">max_weight</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Get the stored data for the sampled indexes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">252</span> <span class="n">samples</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="p">[</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;indexes&#39;</span><span class="p">]]</span>
<span class="lineno">253</span>
<span class="lineno">254</span> <span class="k">return</span> <span class="n">samples</span></pre></div>
</div>
</div>
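<div class='section'>
<div class='docs'>
<p>The weight computation above can be checked on a few hand-picked numbers.
The following is a minimal sketch in plain Python, assuming the values
$p_i^\alpha$ are kept in a flat list instead of the segment trees;
<code>importance_weights</code> is a hypothetical helper, not part of the buffer.
After normalization each weight equals $\big(\frac{\min_j P(j)}{P(i)}\big)^\beta$,
so the largest weight is always 1.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def importance_weights(priorities_alpha, indexes, beta):
    """Normalized importance-sampling weights for the sampled indexes.

    priorities_alpha holds the already exponentiated priorities p_i ** alpha.
    """
    total = sum(priorities_alpha)
    n = len(priorities_alpha)
    # The smallest sampling probability produces the largest weight.
    prob_min = min(priorities_alpha) / total
    max_weight = (prob_min * n) ** (-beta)
    weights = []
    for idx in indexes:
        prob = priorities_alpha[idx] / total
        weight = (prob * n) ** (-beta)
        # Dividing by max_weight removes the 1/N factor.
        weights.append(weight / max_weight)
    return weights


priorities_alpha = [1.0, 2.0, 4.0, 1.0]
weights = importance_weights(priorities_alpha, indexes=[0, 1, 2], beta=1.0)
```
</pre></div>
</div>
</div>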
<div class='section' id='section-47'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<h3>Update priorities</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">256</span> <span class="k">def</span> <span class="nf">update_priorities</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">indexes</span><span class="p">,</span> <span class="n">priorities</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">priority</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="n">priorities</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Keep track of the largest priority seen so far</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_priority</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_priority</span><span class="p">,</span> <span class="n">priority</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Calculate $p_i^\alpha$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">266</span> <span class="n">priority_alpha</span> <span class="o">=</span> <span class="n">priority</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Update the trees</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set_priority_min</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">priority_alpha</span><span class="p">)</span>
<span class="lineno">269</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set_priority_sum</span><span class="p">(</span><span class="n">idx</span><span class="p">,</span> <span class="n">priority_alpha</span><span class="p">)</span></pre></div>
</div>
</div>
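<div class='section'>
<div class='docs'>
<p>As a rough illustration of what updating both trees involves, the sketch below
writes a leaf and then recomputes its ancestors up to the root. It assumes the
array layout with the root at index 1 and the leaves at <code>capacity + idx</code>;
the helper <code>set_leaf</code> is hypothetical and stands in for the two setter
methods called above, whose actual code may differ.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def set_leaf(tree, idx, value, combine):
    """Write value at leaf idx, then recompute every ancestor up to the root."""
    capacity = len(tree) // 2
    node = idx + capacity
    tree[node] = value
    while node >= 2:
        # Move to the parent and rebuild it from its two children.
        node //= 2
        tree[node] = combine(tree[2 * node], tree[2 * node + 1])


capacity = 4
alpha = 0.5
sum_tree = [0.0] * (2 * capacity)
min_tree = [float('inf')] * (2 * capacity)

# Store p_i ** alpha for four items in both trees.
for idx, priority in zip([0, 1, 2, 3], [1.0, 4.0, 9.0, 16.0]):
    priority_alpha = priority ** alpha
    set_leaf(sum_tree, idx, priority_alpha, lambda a, b: a + b)
    set_leaf(min_tree, idx, priority_alpha, min)

# Lower the priority of item 2 and read the aggregates at the root.
set_leaf(sum_tree, 2, 0.25 ** alpha, lambda a, b: a + b)
set_leaf(min_tree, 2, 0.25 ** alpha, min)
total, smallest = sum_tree[1], min_tree[1]
```
</pre></div>
</div>
</div>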
<div class='section' id='section-52'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<h3>Whether the buffer is full</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">271</span> <span class="k">def</span> <span class="nf">is_full</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">275</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>