mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
📚 annotations
labml_nn/rl/dqn/__init__.py
@@ -1,13 +1,24 @@
"""
|
||||
This is a Deep Q Learning implementation with:
|
||||
# Deep Q Networks
|
||||
|
||||
This is a Deep Q Learning implementation that uses:
|
||||
|
||||
* [Dueling Network](model.html)
|
||||
* [Prioritized Replay](replay_buffer.html)
|
||||
* Double Q Network
|
||||
* Dueling Network
|
||||
* Prioritized Replay
|
||||
|
||||
Here's the [experiment](experiment.html) and [model](model.html).
|
||||
|
||||
\(
|
||||
\def\green#1{{\color{yellowgreen}{#1}}}
|
||||
\)
|
||||
|
||||
"""
|
||||
|
||||
from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
@@ -15,36 +26,133 @@ from labml_nn.rl.dqn.replay_buffer import ReplayBuffer


class QFuncLoss(Module):
"""
|
||||
## Train the model
|
||||
|
||||
We want to find optimal action-value function.
|
||||
|
||||
\begin{align}
|
||||
Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
|
||||
r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
|
||||
\Big]
|
||||
\\
|
||||
Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
|
||||
r + \gamma \max_{a'} Q^* (s', a') | s, a
|
||||
\Big]
|
||||
\end{align}
|
||||
|
||||
### Target network 🎯
|
||||
In order to improve stability we use experience replay that randomly sample
|
||||
from previous experience $U(D)$. We also use a Q network
|
||||
with a separate set of paramters $\color{orangle}{\theta_i^{-}}$ to calculate the target.
|
||||
$\color{orangle}{\theta_i^{-}}$ is updated periodically.
|
||||
This is according to paper
|
||||
[Human Level Control Through Deep Reinforcement Learning](https://deepmind.com/research/dqn/).
|
||||
|
||||
So the loss function is,
|
||||
$$
|
||||
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
|
||||
\bigg[
|
||||
\Big(
|
||||
r + \gamma \max_{a'} Q(s', a'; \color{orange}{\theta_i^{-}}) - Q(s,a;\theta_i)
|
||||
\Big) ^ 2
|
||||
\bigg]
|
||||
$$
|
||||
|
||||
### Double $Q$-Learning
|
||||
The max operator in the above calculation uses same network for both
|
||||
selecting the best action and for evaluating the value.
|
||||
That is,
|
||||
$$
|
||||
\max_{a'} Q(s', a'; \theta) = \color{cyan}{Q}
|
||||
\Big(
|
||||
s', \mathop{\operatorname{argmax}}_{a'}
|
||||
\color{cyan}{Q}(s', a'; \color{cyan}{\theta}); \color{cyan}{\theta}
|
||||
\Big)
|
||||
$$
|
||||
We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
|
||||
the $\operatorname{argmax}$ is taken from $\color{cyan}{\theta_i}$ and
|
||||
the value is taken from $\color{orange}{\theta_i^{-}}$.
|
||||
|
||||
And the loss function becomes,
|
||||
\begin{align}
|
||||
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
|
||||
\Bigg[
|
||||
\bigg(
|
||||
&r + \gamma \color{orange}{Q}
|
||||
\Big(
|
||||
s',
|
||||
\mathop{\operatorname{argmax}}_{a'}
|
||||
\color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
|
||||
\Big)
|
||||
\\
|
||||
- &Q(s,a;\theta_i)
|
||||
\bigg) ^ 2
|
||||
\Bigg]
|
||||
\end{align}
|
||||
"""
|
||||
|
||||

    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma
        self.huber_loss = nn.SmoothL1Loss(reduction='none')

    def __call__(self, q: torch.Tensor,
                 action: torch.Tensor,
                 double_q: torch.Tensor,
                 target_q: torch.Tensor,
                 done: torch.Tensor,
                 reward: torch.Tensor,
    def __call__(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        * `q` - $Q(s;\theta_i)$
        * `action` - $a$
        * `double_q` - $\color{cyan}Q(s';\color{cyan}{\theta_i})$
        * `target_q` - $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
        * `done` - whether the game ended after taking the action
        * `reward` - $r$
        * `weights` - weights of the samples from prioritized experience replay
        """

        # $Q(s,a;\theta_i)$
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

        # Gradients shouldn't propagate through the target calculation
        # $$r + \gamma \color{orange}{Q}
        # \Big(s',
        #     \mathop{\operatorname{argmax}}_{a'}
        #         \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
        # \Big)$$
        with torch.no_grad():
            # Get the best action at state $s'$
            # $$\mathop{\operatorname{argmax}}_{a'}
            #  \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i})$$
            best_next_action = torch.argmax(double_q, -1)
            # Get the q value from the target network for the best action at state $s'$
            # $$\color{orange}{Q}
            # \Big(s',\mathop{\operatorname{argmax}}_{a'}
            #  \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
            # \Big)$$
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

            best_next_q_value *= (1 - done)

            q_update = reward + self.gamma * best_next_q_value
            # Calculate the desired Q value.
            # We multiply by `(1 - done)` to zero out
            # the next state Q values if the game ended.
            #
            # $$r + \gamma \color{orange}{Q}
            # \Big(s',
            #     \mathop{\operatorname{argmax}}_{a'}
            #         \color{cyan}{Q}(s', a'; \color{cyan}{\theta_i}); \color{orange}{\theta_i^{-}}
            # \Big)$$
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)

            # Temporal difference error $\delta$ is used to weigh samples in the replay buffer
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

        # Huber loss
        losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
        # We take [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) instead of
        # mean squared error loss because it is less sensitive to outliers
        losses = self.huber_loss(q_sampled_action, q_update)
        # Take the weighted mean
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
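The target parameters $\color{orange}{\theta_i^{-}}$ described in the docstring above are refreshed by periodically copying the training network's parameters; that update is not part of this loss class. Below is a minimal sketch of such a periodic copy, assuming the `Model` from [model.html](model.html); the helper name and update interval are illustrative, not the repository's training loop:

```python
import copy

from labml_nn.rl.dqn.model import Model

# Training network (parameters θ_i) and target network (parameters θ_i⁻)
model = Model()
target_model = copy.deepcopy(model)


def sync_target_model(step: int, period: int = 250):
    # Copy θ_i into θ_i⁻ every `period` training steps (the interval is an arbitrary choice here)
    if step % period == 0:
        target_model.load_state_dict(model.state_dict())
```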

labml_nn/rl/dqn/experiment.py
@@ -8,12 +8,11 @@

import numpy as np
import torch
from torch import nn

from labml import tracker, experiment, logger, monit
from labml_helpers.module import Module
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

@@ -23,105 +22,12 @@ else:
    device = torch.device("cpu")


class Model(Module):
    """
    ## <a name="model"></a>Neural Network Model for $Q$ Values

    #### Dueling Network ⚔️
    We are using a [dueling network](https://arxiv.org/abs/1511.06581)
    to calculate Q-values.
    Intuition behind dueling network architecture is that in most states
    the action doesn't matter,
    and in some states the action is significant. Dueling network allows
    this to be represented very well.

    \begin{align}
        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
        \\
        \mathop{\mathbb{E}}_{a \sim \pi(s)}
            \Big[
                A^\pi(s, a)
            \Big]
        &= 0
    \end{align}

    So we create two networks for $V$ and $A$ and get $Q$ from them.
    $$
    Q(s, a) = V(s) +
    \Big(
        A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
    \Big)
    $$
    We share the initial layers of the $V$ and $A$ networks.
    """

    def __init__(self):
        """
        ### Initialize

        We need `scope` because we need multiple copies of variables
        for target network and training network.
        """

        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes a
            # 84x84 frame and produces a 20x20 frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # 20x20 frame and produces a 9x9 frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # 9x9 frame and produces a 7x7 frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # A fully connected layer takes the flattened
        # frame from third convolution layer, and outputs
        # 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        self.state_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        self.action_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

        #
        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        h = self.conv(obs)
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        action_score = self.action_score(h)
        state_score = self.state_score(h)

        # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
        q = state_score + action_score_centered

        return q


def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """Scale observations from `[0, 255]` to `[0, 1]`"""
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
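A quick sketch of what `obs_to_torch` does, reusing the imports and `device` above: a batch of `uint8` frame stacks in `[0, 255]` comes out as a `float32` tensor in `[0, 1]`. The batch size here is an arbitrary example:

```python
# A batch of 8 observations, each a stack of four 84x84 uint8 frames
dummy_obs = np.random.randint(0, 256, size=(8, 4, 84, 84), dtype=np.uint8)
scaled = obs_to_torch(dummy_obs)
assert scaled.dtype == torch.float32
assert scaled.shape == (8, 4, 84, 84)
assert float(scaled.min()) >= 0.0 and float(scaled.max()) <= 1.0
```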


class Main(object):
class Trainer:
    """
    ## <a name="main"></a>Main class
    This class runs the training loop.
@@ -239,71 +145,6 @@ class Main(object):
            self.obs[w] = next_obs

    def train(self, beta: float):
        """
        ### Train the model

        We want to find the optimal action-value function.

        \begin{align}
            Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
                r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
            \Big]
            \\
            Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
                r + \gamma \max_{a'} Q^* (s', a') | s, a
            \Big]
        \end{align}

        #### Target network 🎯
        In order to improve stability we use experience replay that randomly samples
        from previous experience $U(D)$. We also use a Q network
        with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
        $\hl1{\theta_i^{-}}$ is updated periodically.
        This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).

        So the loss function is,
        $$
        \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
        \bigg[
            \Big(
                r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
            \Big) ^ 2
        \bigg]
        $$

        #### Double $Q$-Learning
        The max operator in the above calculation uses the same network for both
        selecting the best action and for evaluating the value.
        That is,
        $$
        \max_{a'} Q(s', a'; \theta) = \blue{Q}
        \Big(
            s', \mathop{\operatorname{argmax}}_{a'}
            \blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
        \Big)
        $$
        We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
        the $\operatorname{argmax}$ is taken from $\theta_i$ and
        the value is taken from $\theta_i^{-}$.

        And the loss function becomes,
        \begin{align}
            \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
            \Bigg[
                \bigg(
                    &r + \gamma \blue{Q}
                    \Big(
                        s',
                        \mathop{\operatorname{argmax}}_{a'}
                            \green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
                    \Big)
                    \\
                    - &Q(s,a;\theta_i)
                \bigg) ^ 2
            \Bigg]
        \end{align}
        """
        for _ in range(self.train_epochs):
            # Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
@@ -379,7 +220,7 @@ class Main(object):
# ## Run it
if __name__ == "__main__":
    experiment.create(name='dqn')
    m = Main()
    m = Trainer()
    with experiment.start():
        m.run_training_loop()
    m.destroy()
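Putting the pieces together, here is a rough sketch of a single training step with prioritized replay and the double Q target, reusing `obs_to_torch` and `device` from above. The sample keys (`'obs'`, `'action'`, `'reward'`, `'next_obs'`, `'done'`, `'weights'`, `'indexes'`), the `update_priorities` method, and the optimizer settings are assumptions made for illustration; they are not taken verbatim from this diff:

```python
model = Model().to(device)          # θ_i
target_model = Model().to(device)   # θ_i⁻, synced periodically as sketched earlier
loss_func = QFuncLoss(gamma=0.99)
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)


def train_step(replay_buffer: ReplayBuffer, mini_batch_size: int, beta: float):
    # Sample a prioritized mini-batch; `beta` is annealed towards 1 during training
    samples = replay_buffer.sample(mini_batch_size, beta)

    obs = obs_to_torch(samples['obs'])
    next_obs = obs_to_torch(samples['next_obs'])

    q = model(obs)
    with torch.no_grad():
        double_q = model(next_obs)         # used for the argmax (θ_i)
        target_q = target_model(next_obs)  # used for the value (θ_i⁻)

    def to_tensor(x):
        return torch.tensor(x, dtype=torch.float32, device=device)

    td_error, loss = loss_func(q, to_tensor(samples['action']),
                               double_q, target_q,
                               to_tensor(samples['done']),
                               to_tensor(samples['reward']),
                               to_tensor(samples['weights']))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Feed |δ| back to the buffer as new priorities (method name is an assumption)
    new_priorities = torch.abs(td_error.detach()).cpu().numpy() + 1e-6
    replay_buffer.update_priorities(samples['indexes'], new_priorities)
```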

labml_nn/rl/dqn/model.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""
# Neural Network Model
"""

import torch
from torch import nn

from labml_helpers.module import Module


class Model(Module):
    """
    ## Dueling Network ⚔️ Model for $Q$ Values

    We are using a [dueling network](https://arxiv.org/abs/1511.06581)
    to calculate Q-values.
    Intuition behind dueling network architecture is that in most states
    the action doesn't matter,
    and in some states the action is significant. Dueling network allows
    this to be represented very well.

    \begin{align}
        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
        \\
        \mathop{\mathbb{E}}_{a \sim \pi(s)}
            \Big[
                A^\pi(s, a)
            \Big]
        &= 0
    \end{align}

    So we create two networks for $V$ and $A$ and get $Q$ from them.
    $$
    Q(s, a) = V(s) +
    \Big(
        A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
    \Big)
    $$
    We share the initial layers of the $V$ and $A$ networks.
    """

    def __init__(self):
        """
        ### Initialize

        We need `scope` because we need multiple copies of variables
        for target network and training network.
        """

        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes a
            # 84x84 frame and produces a 20x20 frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # 20x20 frame and produces a 9x9 frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # 9x9 frame and produces a 7x7 frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # A fully connected layer takes the flattened
        # frame from third convolution layer, and outputs
        # 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        self.state_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        self.action_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

        #
        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        h = self.conv(obs)
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        action_score = self.action_score(h)
        state_score = self.state_score(h)

        # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
        q = state_score + action_score_centered

        return q
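A small usage sketch of this model (the batch size is arbitrary): a stack of four 84x84 frames scaled to `[0, 1]` goes in, one $Q$ value per action comes out, and because the advantage head is mean-centered, the per-state mean of $Q$ recovers the value head's output $V(s)$:

```python
model = Model()
frames = torch.rand(2, 4, 84, 84)  # already scaled to [0, 1]
q = model(frames)
assert q.shape == (2, 4)  # one Q value per action

# The centered advantages average to zero over actions,
# so the mean of Q over actions equals the V head's output
h = model.activation(model.lin(model.conv(frames).reshape(-1, 7 * 7 * 64)))
v = model.state_score(h)
assert torch.allclose(q.mean(dim=-1, keepdim=True), v, atol=1e-5)
```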

labml_nn/rl/dqn/replay_buffer.py
@@ -1,3 +1,10 @@
"""
# Prioritized Experience Replay Buffer

This implements the paper [Prioritized experience replay](https://arxiv.org/abs/1511.05952),
using a binary segment tree.
"""

import numpy as np
import random
@@ -5,9 +12,10 @@ import random
class ReplayBuffer:
    """
    ## Buffer for Prioritized Experience Replay

    [Prioritized experience replay](https://arxiv.org/abs/1511.05952)
    samples important transitions more frequently.
    The transitions are prioritized by the Temporal Difference error.
    The transitions are prioritized by the Temporal Difference error (td error).

    We sample transition $i$ with probability,
    $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
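    For instance, a minimal sketch of how these sampling probabilities behave for a toy set of priorities (the values and $\alpha$ here are only illustrative):

```python
priorities = np.array([1.0, 2.0, 3.0])
alpha = 0.6

scaled = priorities ** alpha
probs = scaled / scaled.sum()  # P(i) = p_i^α / Σ_k p_k^α
# `probs` sums to 1; transitions with larger priorities are sampled
# more often, but smaller ones still get a non-zero chance
```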
@@ -21,16 +29,16 @@ class ReplayBuffer:
    importance-sampling (IS) weights
    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
    that fully compensate for the non-uniform sampling when $\beta = 1$.
    We normalize weights by $1/\max_i w_i$ for stability.
    We normalize weights by $\frac{1}{\max_i w_i}$ for stability.
    The unbiased nature is most important for convergence towards the end of training.
    Therefore we increase $\beta$ towards the end of training.

    ### Binary Segment Trees
    We use binary segment trees to efficiently calculate
    ### Binary Segment Tree
    We use a binary segment tree to efficiently calculate
    $\sum_k^i p_k^\alpha$, the cumulative probability,
    which is needed to sample.
    We also use a binary segment tree to find $\min p_i^\alpha$,
    which is needed for $1/\max_i w_i$.
    which is needed for $\frac{1}{\max_i w_i}$.
    We can also use a min-heap for this.

    This is how a binary segment tree works for sum;
@@ -54,14 +62,16 @@ class ReplayBuffer:
    $$N_i = \left\lceil{\frac{N}{2^{D - i}}} \right\rceil$$
    This is equal to the sum of nodes in all rows above $i$.
    So we can use a single array $a$ to store the tree, where,
    $$b_{i,j} = a_{N_1 + j}$$
    $$b_{i,j} \rightarrow a_{N_i + j}$$

    Then child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
    That is,
    $$a_i = a_{2i} + a_{2i + 1}$$

    This way of maintaining binary trees is very easy to program.
    *Note that we are indexing from 1*.
    *Note that we are indexing starting from 1*.

    We use the same structure to compute the minimum.
    """

    def __init__(self, capacity, alpha):
@@ -206,7 +216,7 @@ class ReplayBuffer:
            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
            weight = (prob * self.size) ** (-beta)
            # normalize by $\frac{1}{\max_i w_i}$,
            # which also cancels off the $\frac{1}/{N}$ term
            # which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight

        # get samples data
@@ -230,7 +240,5 @@ class ReplayBuffer:
    def is_full(self):
        """
        ### Is the buffer full

        We only start sampling after the buffer is full.
        """
        return self.capacity == self.size