📚 annotations
@@ -1,5 +1,5 @@
 """
-# Deep Q Networks
+# Deep Q Networks (DQN)

 This is an implementation of paper
 [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602)
@@ -65,7 +65,7 @@ class Trainer:
                 (self.updates, 1)
             ], outside_value=1)

-        # replay buffer with $\alpha = 0.6$
+        # Replay buffer with $\alpha = 0.6$. Capacity of the replay buffer must be a power of 2.
         self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

         # Model for sampling and training
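The power-of-two requirement comes from the replay buffer's binary segment trees (described below), which index a full binary tree with `capacity` leaves; that layout is simplest when the capacity is a power of two. A quick standalone check, as an illustration only (not part of the commit):

```python
capacity = 2 ** 14
# A power of two has exactly one bit set, so this bit trick verifies it
assert capacity > 0 and capacity & (capacity - 1) == 0
```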
@@ -1,5 +1,5 @@
 """
-# Neural Network Model
+# Neural Network Model for Deep Q Network (DQN)
 """

 import torch
@@ -40,61 +40,60 @@ class Model(Module):
     """

     def __init__(self):
-        """
-        ### Initialize
-
-        We need `scope` because we need multiple copies of variables
-        for target network and training network.
-        """
-
         super().__init__()
         self.conv = nn.Sequential(
             # The first convolution layer takes a
-            # 84x84 frame and produces a 20x20 frame
+            # $84\times84$ frame and produces a $20\times20$ frame
             nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
             nn.ReLU(),

             # The second convolution layer takes a
-            # 20x20 frame and produces a 9x9 frame
+            # $20\times20$ frame and produces a $9\times9$ frame
             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
             nn.ReLU(),

             # The third convolution layer takes a
-            # 9x9 frame and produces a 7x7 frame
+            # $9\times9$ frame and produces a $7\times7$ frame
             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
             nn.ReLU(),
         )

         # A fully connected layer takes the flattened
         # frame from third convolution layer, and outputs
-        # 512 features
+        # $512$ features
         self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
+        self.activation = nn.ReLU()

-        self.state_score = nn.Sequential(
+        # This head gives the state value $V$
+        self.state_value = nn.Sequential(
             nn.Linear(in_features=512, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=1),
         )
-        self.action_score = nn.Sequential(
+        # This head gives the action value $A$
+        self.action_value = nn.Sequential(
             nn.Linear(in_features=512, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=4),
         )

-        #
-        self.activation = nn.ReLU()
-
     def __call__(self, obs: torch.Tensor):
+        # Convolution
         h = self.conv(obs)
+        # Reshape for linear layers
         h = h.reshape((-1, 7 * 7 * 64))

+        # Linear layer
         h = self.activation(self.lin(h))

-        action_score = self.action_score(h)
-        state_score = self.state_score(h)
+        # $A$
+        action_value = self.action_value(h)
+        # $V$
+        state_value = self.state_value(h)

+        # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
+        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
         # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
-        action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
-        q = state_score + action_score_centered
+        q = state_value + action_score_centered

         return q
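For illustration, here is a minimal sketch (not part of the commit) of how the dueling aggregation above combines the two heads; the tensor values are made up:

```python
import torch

# V(s) for a batch of one state, and A(s, a) for four actions
state_value = torch.tensor([[1.5]])
action_value = torch.tensor([[2.0, 0.0, 1.0, 1.0]])

# A(s, a) - mean over a' of A(s, a')
action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

# Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a'))
q = state_value + action_score_centered
print(q)  # tensor([[2.5000, 0.5000, 1.5000, 1.5000]])
```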
@@ -1,13 +1,14 @@
 """
-# Prioritized Experience Replace Buffer
+# Prioritized Experience Replay Buffer

 This implements paper [Prioritized experience replay](https://arxiv.org/abs/1511.05952),
 using a binary segment tree.
 """

-import numpy as np
 import random

+import numpy as np


 class ReplayBuffer:
     """
@@ -15,20 +16,21 @@ class ReplayBuffer:

     [Prioritized experience replay](https://arxiv.org/abs/1511.05952)
     samples important transitions more frequently.
-    The transitions are prioritized by the Temporal Difference error (td error).
+    The transitions are prioritized by the Temporal Difference error (TD error), $\delta$.

     We sample transition $i$ with probability,
     $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
     where $\alpha$ is a hyper-parameter that determines how much
     prioritization is used, with $\alpha = 0$ corresponding to the uniform case.
+    $p_i$ is the priority.

     We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where
     $\delta_i$ is the temporal difference for transition $i$.

-    We correct the bias introduced by prioritized replay by
+    We correct the bias introduced by prioritized replay using
     importance-sampling (IS) weights
-    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
-    that fully compensates for when $\beta = 1$.
+    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$ in the loss function.
+    This fully compensates when $\beta = 1$.
     We normalize weights by $\frac{1}{\max_i w_i}$ for stability.
     The unbiased nature is most important for convergence, towards the end of training.
     Therefore we increase $\beta$ towards the end of training.
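As a rough illustration of these formulas, here is a standalone sketch with made-up priorities (not part of the buffer code):

```python
import numpy as np

# Toy priorities p_i = |delta_i| + epsilon for four stored transitions
p = np.array([2.0, 0.5, 1.0, 0.5])
alpha, beta = 0.6, 0.4

# P(i) = p_i^alpha / sum_k p_k^alpha
probs = p ** alpha / np.sum(p ** alpha)

# w_i = (1/N * 1/P(i))^beta, then normalized by max_i w_i for stability
weights = (len(p) * probs) ** (-beta)
weights = weights / weights.max()
```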
@@ -40,6 +42,9 @@ class ReplayBuffer:
     We also use a binary segment tree to find $\min p_i^\alpha$,
     which is needed for $\frac{1}{\max_i w_i}$.
     We can also use a min-heap for this.
+    A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$
+    time, which is much more efficient than the naive $\mathcal{O}(n)$
+    approach.

     This is how a binary segment tree works for sum;
     it is similar for minimum.
@@ -51,15 +56,16 @@ class ReplayBuffer:
     The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
     will have values of $x$.
     Every node keeps the sum of the two child nodes.
-    So the root node keeps the sum of the entire array of values.
-    The two children of the root node keep
+    That is, the root node keeps the sum of the entire array of values.
+    The left and right children of the root node keep
     the sum of the first half of the array and
-    the sum of the second half of the array, and so on.
+    the sum of the second half of the array, respectively.
+    And so on...

     $$b_{i,j} = \sum_{k = (j -1) * 2^{D - i} + 1}^{j * 2^{D - i}} x_k$$

     Number of nodes in row $i$,
-    $$N_i = \left\lceil{\frac{N}{D - i + i}} \right\rceil$$
+    $$N_i = \left\lceil{\frac{N}{2^{D - i}}} \right\rceil$$
     This is equal to the sum of nodes in all rows above $i$.
     So we can use a single array $a$ to store the tree, where,
     $$b_{i,j} \rightarrow a_{N_i + j}$$
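To make the flat-array layout concrete, here is a small standalone sketch (not part of the commit) that builds such a sum tree for four made-up values:

```python
# Leaves are stored at a[capacity .. 2 * capacity - 1]; every parent a[i]
# is the sum of its children a[2 * i] and a[2 * i + 1]; a[1] is the root.
capacity = 4
x = [2.0, 1.0, 3.0, 0.5]

a = [0.0] * (2 * capacity)
a[capacity:] = x                        # bottom row (the leaves)
for i in range(capacity - 1, 0, -1):    # fill internal rows bottom-up
    a[i] = a[2 * i] + a[2 * i + 1]

print(a[1])        # 6.5, the sum of the whole array
print(a[2], a[3])  # 3.0 3.5, the sums of the two halves
```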
@@ -71,28 +77,26 @@ class ReplayBuffer:
     This way of maintaining binary trees is very easy to program.
     *Note that we are indexing starting from 1*.

-    We using the same structure to compute the minimum.
+    We use the same structure to compute the minimum.
     """

     def __init__(self, capacity, alpha):
         """
         ### Initialize
         """
-        # we use a power of 2 for capacity to make it easy to debug
+        # We use a power of $2$ for capacity because it simplifies the code and debugging
         self.capacity = capacity
-        # we refill the queue once it reaches capacity
-        self.next_idx = 0
         # $\alpha$
         self.alpha = alpha

-        # maintain segment binary trees to take sum and find minimum over a range
+        # Maintain binary segment trees to take the sum and find the minimum over a range
         self.priority_sum = [0 for _ in range(2 * self.capacity)]
         self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

-        # current max priority, $p$, to be assigned to new transitions
+        # Current max priority, $p$, to be assigned to new transitions
         self.max_priority = 1.

-        # arrays for buffer
+        # Arrays for buffer
         self.data = {
             'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
             'action': np.zeros(shape=capacity, dtype=np.int32),
@@ -100,8 +104,11 @@ class ReplayBuffer:
             'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
             'done': np.zeros(shape=capacity, dtype=np.bool)
         }
+        # We use cyclic buffers to store data, and `next_idx` keeps the index of the next
+        # empty slot
+        self.next_idx = 0

-        # size of the buffer
+        # Size of the buffer
         self.size = 0

     def add(self, obs, action, reward, next_obs, done):
@@ -109,6 +116,7 @@ class ReplayBuffer:
         ### Add sample to queue
         """

+        # Get next available slot
         idx = self.next_idx

         # store in the queue
@@ -118,12 +126,14 @@ class ReplayBuffer:
         self.data['next_obs'][idx] = next_obs
         self.data['done'][idx] = done

-        # increment head of the queue and calculate the size
+        # Increment next available slot
         self.next_idx = (idx + 1) % self.capacity
+        # Calculate the size
         self.size = min(self.capacity, self.size + 1)

         # $p_i^\alpha$, new samples get `max_priority`
         priority_alpha = self.max_priority ** self.alpha
+        # Update the two segment trees for sum and minimum
         self._set_priority_min(idx, priority_alpha)
         self._set_priority_sum(idx, priority_alpha)

@@ -132,40 +142,50 @@ class ReplayBuffer:
         #### Set priority in binary segment tree for minimum
         """

-        # leaf of the binary tree
+        # Leaf of the binary tree
         idx += self.capacity
         self.priority_min[idx] = priority_alpha

-        # update tree, by traversing along ancestors
+        # Update tree, by traversing along ancestors.
+        # Continue until the root of the tree.
         while idx >= 2:
+            # Get the index of the parent node
             idx //= 2
-            self.priority_min[idx] = min(self.priority_min[2 * idx],
-                                         self.priority_min[2 * idx + 1])
+            # Value of the parent node is the minimum of its two children
+            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

     def _set_priority_sum(self, idx, priority):
         """
         #### Set priority in binary segment tree for sum
         """

-        # leaf of the binary tree
+        # Leaf of the binary tree
         idx += self.capacity
+        # Set the priority at the leaf
         self.priority_sum[idx] = priority

-        # update tree, by traversing along ancestors
+        # Update tree, by traversing along ancestors.
+        # Continue until the root of the tree.
         while idx >= 2:
+            # Get the index of the parent node
             idx //= 2
+            # Value of the parent node is the sum of its two children
             self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

     def _sum(self):
         """
         #### $\sum_k p_k^\alpha$
         """

+        # The root node keeps the sum of all values
         return self.priority_sum[1]

     def _min(self):
         """
         #### $\min_k p_k^\alpha$
         """

+        # The root node keeps the minimum of all values
         return self.priority_min[1]

     def find_prefix_sum_idx(self, prefix_sum):
@@ -173,19 +193,21 @@ class ReplayBuffer:
         #### Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$
         """

-        # start from the root
+        # Start from the root
         idx = 1
         while idx < self.capacity:
-            # if the sum of the left branch is higher than required sum
+            # If the sum of the left branch is higher than required sum
             if self.priority_sum[idx * 2] > prefix_sum:
-                # go to left branch if the tree if the
+                # Go to left branch of the tree
                 idx = 2 * idx
             else:
-                # otherwise go to right branch and reduce the sum of left
+                # Otherwise go to right branch and reduce the sum of left
                 # branch from required sum
                 prefix_sum -= self.priority_sum[idx * 2]
                 idx = 2 * idx + 1

+        # We are at a leaf node. Subtract the capacity from the tree index
+        # to get the index of the actual value
         return idx - self.capacity

     def sample(self, batch_size, beta):
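Continuing the toy sum tree from the earlier sketch, this standalone function (an illustration only, not the class method) traces the same descent for a prefix sum of 4.0, which lands in the third leaf's segment:

```python
def find_prefix_sum_idx(a, capacity, prefix_sum):
    # Descend from the root: go left while the left subtree covers prefix_sum,
    # otherwise subtract the left subtree's sum and go right
    idx = 1
    while idx < capacity:
        if a[2 * idx] > prefix_sum:
            idx = 2 * idx
        else:
            prefix_sum -= a[2 * idx]
            idx = 2 * idx + 1
    return idx - capacity

# The flat array from the earlier sketch, with leaves [2.0, 1.0, 3.0, 0.5]
a = [0.0, 6.5, 3.0, 3.5, 2.0, 1.0, 3.0, 0.5]
print(find_prefix_sum_idx(a, 4, 4.0))  # 2, since 2.0 + 1.0 <= 4.0 < 2.0 + 1.0 + 3.0
```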
@@ -193,12 +215,13 @@ class ReplayBuffer:
         ### Sample from buffer
         """

+        # Initialize samples
         samples = {
             'weights': np.zeros(shape=batch_size, dtype=np.float32),
             'indexes': np.zeros(shape=batch_size, dtype=np.int32)
         }

-        # get samples
+        # Get sample indexes
         for i in range(batch_size):
             p = random.random() * self._sum()
             idx = self.find_prefix_sum_idx(p)
@@ -215,11 +238,11 @@ class ReplayBuffer:
             prob = self.priority_sum[idx + self.capacity] / self._sum()
             # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
             weight = (prob * self.size) ** (-beta)
-            # normalize by $\frac{1}{\max_i w_i}$,
+            # Normalize by $\frac{1}{\max_i w_i}$,
             # which also cancels off the $\frac{1}{N}$ term
             samples['weights'][i] = weight / max_weight

-        # get samples data
+        # Get samples data
         for k, v in self.data.items():
             samples[k] = v[samples['indexes']]

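To spell out the comment about the $\frac{1}{N}$ term: every weight in the batch carries the same $\frac{1}{N}$ factor, so it cancels in the normalization. A short derivation (not part of the commit), writing $P_{\min}$ for the smallest sampling probability, which comes from $\min_k p_k^\alpha$:

$$\frac{w_i}{\max_j w_j} = \frac{\big(N \, P(i)\big)^{-\beta}}{\big(N \, P_{\min}\big)^{-\beta}} = \left(\frac{P_{\min}}{P(i)}\right)^{\beta}$$

This is why the code can compute `(prob * self.size) ** (-beta)` and divide by `max_weight` without ever tracking $\frac{1}{N}$ separately.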
@@ -229,16 +252,19 @@ class ReplayBuffer:
         """
         ### Update priorities
         """

         for idx, priority in zip(indexes, priorities):
+            # Set current max priority
             self.max_priority = max(self.max_priority, priority)

-            # $p_i^\alpha$
+            # Calculate $p_i^\alpha$
             priority_alpha = priority ** self.alpha
+            # Update the trees
             self._set_priority_min(idx, priority_alpha)
             self._set_priority_sum(idx, priority_alpha)

     def is_full(self):
         """
-        ### Is the buffer full
+        ### Whether the buffer is full
         """
         return self.capacity == self.size
|
|||||||
Reference in New Issue
Block a user