mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 21:40:15 +08:00
✨ DQN
This commit is contained in:
@@ -0,0 +1,50 @@
"""
This is a Deep Q Learning implementation with:

* Double Q Network
* Dueling Network
* Prioritized Replay
"""

from typing import Tuple

import torch

from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer


class QFuncLoss(Module):
    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma

    def __call__(self, q: torch.Tensor,
                 action: torch.Tensor,
                 double_q: torch.Tensor,
                 target_q: torch.Tensor,
                 done: torch.Tensor,
                 reward: torch.Tensor,
                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # $Q(s, a; \theta_i)$ for the actions that were actually taken
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)

        # The target shouldn't propagate gradients
        with torch.no_grad():
            # pick the best next action using the online network (double Q-learning)
            best_next_action = torch.argmax(double_q, -1)
            # evaluate that action with the target network
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

            # mask out values for terminal states
            best_next_q_value *= (1 - done)

            # temporal difference target, $r + \gamma \max_{a'} Q(s', a'; \theta_i^{-})$
            q_update = reward + self.gamma * best_next_q_value
            tracker.add('q_update', q_update)

            # temporal difference error $\delta$, used to update replay priorities
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)

        # Huber loss
        losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
        # weigh the losses by the importance-sampling weights from prioritized replay
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
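
Below is a minimal usage sketch of `QFuncLoss` (not part of this commit), assuming the `labml_nn` package is importable. The batch size, the four discrete actions, and the random tensors are illustrative stand-ins for what the replay buffer and the online/target networks produce in the experiment.

# Standalone sketch, not part of the file above.
import torch
from labml_nn.rl.dqn import QFuncLoss

batch_size, n_actions = 8, 4
loss_func = QFuncLoss(gamma=0.99)

q = torch.randn(batch_size, n_actions, requires_grad=True)   # stand-in for Q(s, .) from the online network
double_q = torch.randn(batch_size, n_actions)                 # stand-in for Q(s', .) from the online network
target_q = torch.randn(batch_size, n_actions)                 # stand-in for Q(s', .) from the target network
action = torch.randint(0, n_actions, (batch_size,))           # actions that were taken
reward = torch.randn(batch_size)
done = torch.zeros(batch_size)                                # 1.0 where the episode ended
weights = torch.ones(batch_size)                              # importance-sampling weights

td_error, loss = loss_func(q, action, double_q, target_q, done, reward, weights)
loss.backward()                                               # gradients flow only through `q`
print(td_error.shape, loss.item())
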
385 labml_nn/rl/dqn/experiment.py Normal file

@@ -0,0 +1,385 @@
"""
Color macros used in the math annotations below.
\(
   \def\hl1#1{{\color{orange}{#1}}}
   \def\blue#1{{\color{cyan}{#1}}}
   \def\green#1{{\color{yellowgreen}{#1}}}
\)
"""

import numpy as np
import torch
from torch import nn

from labml import tracker, experiment, logger, monit
from labml_helpers.module import Module
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

# select the device to train on
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class Model(Module):
    """
    ## <a name="model"></a>Neural Network Model for $Q$ Values

    #### Dueling Network ⚔️
    We are using a [dueling network](https://arxiv.org/abs/1511.06581)
    to calculate Q-values.
    The intuition behind the dueling network architecture is that in most states
    the choice of action doesn't matter much,
    while in some states it is critical. The dueling network lets the model
    represent these two aspects separately.

    \begin{align}
        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
        \\
        \mathop{\mathbb{E}}_{a \sim \pi(s)}
            \Big[
                A^\pi(s, a)
            \Big]
        &= 0
    \end{align}

    So we create two networks for $V$ and $A$ and get $Q$ from them.
    $$
        Q(s, a) = V(s) +
        \Big(
            A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
        \Big)
    $$
    We share the initial layers of the $V$ and $A$ networks.
    A standalone sanity check of this aggregation is given after the class below.
    """

    def __init__(self):
        """
        ### Initialize

        We create two separate instances of this model,
        one for the training network and one for the target network.
        """

        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes a
            # 84x84 frame and produces a 20x20 frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # 20x20 frame and produces a 9x9 frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # 9x9 frame and produces a 7x7 frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # A fully connected layer takes the flattened
        # frame from the third convolution layer, and outputs
        # 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        # head that estimates the state value $V(s)$
        self.state_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # head that estimates the action advantages $A(s, a)$ for the 4 actions
        self.action_score = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

        # activation after the shared fully connected layer
        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        # convolutional features
        h = self.conv(obs)
        h = h.reshape((-1, 7 * 7 * 64))

        # shared fully connected layer
        h = self.activation(self.lin(h))

        action_score = self.action_score(h)
        state_score = self.state_score(h)

        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
        q = state_score + action_score_centered

        return q

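The following standalone sketch (not part of `experiment.py`) sanity-checks the dueling aggregation used in `Model.__call__` above: after subtracting the per-state mean advantage, the advantages sum to zero, so the state value equals the mean Q-value over actions. The shapes and random inputs are arbitrary illustration values.

# Standalone sketch of the dueling aggregation, not part of experiment.py.
import torch

batch_size, n_actions = 5, 4
state_score = torch.randn(batch_size, 1)            # V(s)
action_score = torch.randn(batch_size, n_actions)   # A(s, a)

# Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a'))
centered = action_score - action_score.mean(dim=-1, keepdim=True)
q = state_score + centered

# the centered advantages sum to zero for each state,
# and V(s) equals the mean Q-value over actions
assert torch.allclose(centered.sum(dim=-1), torch.zeros(batch_size), atol=1e-6)
assert torch.allclose(q.mean(dim=-1), state_score.squeeze(-1), atol=1e-6)
print(q.shape)  # torch.Size([5, 4])
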

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """Scale observations from `[0, 255]` to `[0, 1]`"""
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Main(object):
    """
    ## <a name="main"></a>Main class
    This class runs the training loop.
    It handles logging and monitoring,
    and runs workers as multiple processes.
    """

    def __init__(self):
        """
        ### Initialize
        """

        # #### Configurations

        # number of workers
        self.n_workers = 8
        # steps sampled from each worker on each update
        self.worker_steps = 4
        # number of training iterations per update
        self.train_epochs = 8

        # number of updates
        self.updates = 1_000_000
        # size of mini batch for training
        self.mini_batch_size = 32

        # exploration coefficient $\epsilon$ as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

        # update the target network every 250 updates
        self.update_target_model = 250

        # $\beta$ for the replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

        # replay buffer with $\alpha = 0.6$
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

        # model for sampling and training, and the target model
        self.model = Model().to(device)
        self.target_model = Model().to(device)

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize an array for the last observation of each worker
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function with $\gamma = 0.99$
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        """
        #### $\epsilon$-greedy Sampling
        When sampling actions we use an $\epsilon$-greedy strategy, where we
        take a greedy action with probability $1 - \epsilon$ and
        take a random action with probability $\epsilon$.
        We refer to $\epsilon$ as the *exploration coefficient*.
        """

        # sampling doesn't need gradients
        with torch.no_grad():
            # the greedy action, $\operatorname{argmax}_a Q(s, a)$
            greedy_action = torch.argmax(q_value, dim=-1)
            # a uniformly random action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

            # whether to choose the random action for each sample
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

            # pick the random action where `is_choose_rand` is true, otherwise the greedy action
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        """### Sample data"""

        # this doesn't need gradients
        with torch.no_grad():
            # sample `worker_steps` steps from each worker
            for t in range(self.worker_steps):
                # sample actions
                q_value = self.model(obs_to_torch(self.obs))
                actions = self._sample_action(q_value, exploration_coefficient)

                # run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # collect information from each worker
                for w, worker in enumerate(self.workers):
                    # get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # add the transition to the replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # collect episode info, which is available when an episode finishes;
                    # this includes the total reward and the length of the episode -
                    # look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

                    # update the current observation
                    self.obs[w] = next_obs

    def train(self, beta: float):
        """
        ### Train the model

        We want to find the optimal action-value function.

        \begin{align}
            Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
                r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
            \Big]
        \\
            Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
                r + \gamma \max_{a'} Q^* (s', a') | s, a
            \Big]
        \end{align}

        #### Target network 🎯
        In order to improve stability we use experience replay that randomly samples
        from previous experience $U(D)$. We also use a Q network
        with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
        $\hl1{\theta_i^{-}}$ is updated periodically.
        This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).

        So the loss function is,
        $$
        \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
        \bigg[
            \Big(
                r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
            \Big) ^ 2
        \bigg]
        $$

        #### Double $Q$-Learning
        The max operator in the above calculation uses the same network for both
        selecting the best action and for evaluating the value.
        That is,
        $$
        \max_{a'} Q(s', a'; \theta) = \blue{Q}
        \Big(
            s', \mathop{\operatorname{argmax}}_{a'}
            \blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
        \Big)
        $$
        We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
        the $\operatorname{argmax}$ is taken from $\theta_i$ and
        the value is taken from $\theta_i^{-}$.

        And the loss function becomes,
        \begin{align}
            \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
            \Bigg[
                \bigg(
                    &r + \gamma \blue{Q}
                    \Big(
                        s',
                        \mathop{\operatorname{argmax}}_{a'}
                            \green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
                    \Big)
                    \\
                    - &Q(s,a;\theta_i)
                \bigg) ^ 2
            \Bigg]
        \end{align}
        """

        for _ in range(self.train_epochs):
            # sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # $Q(s, \cdot; \theta_i)$ from the online network
            q_value = self.model(obs_to_torch(samples['obs']))

            with torch.no_grad():
                # $Q(s', \cdot; \theta_i)$ to pick the best next action
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # $Q(s', \cdot; \theta_i^{-})$ to evaluate it
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

            # compute temporal difference errors and the weighted Huber loss
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

            # $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # update the priorities in the replay buffer
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # compute gradients
            self.optimizer.zero_grad()
            loss.backward()
            # clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            self.optimizer.step()

    def run_training_loop(self):
        """
        ### Run training loop
        """

        # copy the weights to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        # track averages over the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            # $\epsilon$, the exploration coefficient
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # sample with the current policy
            self.sample(exploration)

            # start training once the replay buffer is full
            if self.replay_buffer.is_full():
                # train the model
                self.train(beta)

                # periodically update the target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            # save tracked indicators
            tracker.save()
            # write a line to the log every 1,000 updates
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        ### Destroy
        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))


# ## Run it
if __name__ == "__main__":
    experiment.create(name='dqn')
    m = Main()
    with experiment.start():
        m.run_training_loop()
    m.destroy()

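The exploration coefficient $\epsilon$ and the replay $\beta$ above are produced by `labml_helpers.schedule.Piecewise`. The following is an illustrative re-implementation sketch of such a piecewise-linear schedule (it is not the library class, and its boundary behaviour is an assumption); it mirrors the exploration schedule configured in `Main.__init__`.

# Illustrative re-implementation of a piecewise-linear schedule,
# not the labml_helpers.schedule.Piecewise class itself.
from typing import List, Tuple


def piecewise(points: List[Tuple[float, float]], outside_value: float, x: float) -> float:
    """Linearly interpolate between `(x, y)` points; return `outside_value` beyond the last point."""
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if x0 <= x < x1:
            t = (x - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)
    return outside_value if x >= points[-1][0] else points[0][1]


# Mirrors the exploration schedule above: 1.0 -> 0.1 over the first 25k updates,
# then 0.1 -> 0.01 until half of the 1M updates, and 0.01 afterwards.
exploration = [(0, 1.0), (25_000, 0.1), (500_000, 0.01)]
print(piecewise(exploration, 0.01, 0))        # 1.0
print(piecewise(exploration, 0.01, 12_500))   # 0.55
print(piecewise(exploration, 0.01, 600_000))  # 0.01
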
236 labml_nn/rl/dqn/replay_buffer.py Normal file

@@ -0,0 +1,236 @@
import random

import numpy as np


class ReplayBuffer:
    """
    ## Buffer for Prioritized Experience Replay

    [Prioritized experience replay](https://arxiv.org/abs/1511.05952)
    samples important transitions more frequently.
    The transitions are prioritized by the temporal difference error.

    We sample transition $i$ with probability
    $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
    where $\alpha$ is a hyper-parameter that determines how much
    prioritization is used, with $\alpha = 0$ corresponding to the uniform case.

    We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where
    $\delta_i$ is the temporal difference error for transition $i$.

    We correct the bias introduced by prioritized replay with
    importance-sampling (IS) weights
    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
    which fully compensate for the non-uniform sampling when $\beta = 1$.
    We normalize the weights by $\frac{1}{\max_i w_i}$ for stability.
    Unbiased updates matter most near convergence at the end of training,
    so we increase $\beta$ towards 1 as training progresses.

    ### Binary Segment Trees
    We use binary segment trees to efficiently calculate
    $\sum_{k=1}^{i} p_k^\alpha$, the cumulative sum of priorities,
    which is needed for sampling.
    We also use a binary segment tree to find $\min_i p_i^\alpha$,
    which is needed for $\frac{1}{\max_i w_i}$.
    We could also use a min-heap for the minimum.
    A standalone sketch of the sum tree is given after this listing.

    This is how a binary segment tree works for the sum;
    it is similar for the minimum.
    Let $x_i$ be the list of $N$ values we want to represent.
    Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row
    in the binary tree.
    That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.

    The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
    will have the values of $x$.
    Every node keeps the sum of its two child nodes.
    So the root node keeps the sum of the entire array of values.
    The left and right children of the root node keep
    the sum of the first half of the array and
    the sum of the second half of the array, and so on.

    $$b_{i,j} = \sum_{k = j \cdot 2^{D - i} + 1}^{(j + 1) \cdot 2^{D - i}} x_k$$

    The number of nodes in row $i$ is
    $$N_i = \left\lceil \frac{N}{2^{D - i}} \right\rceil$$
    which is one more than the number of nodes in all rows above $i$.
    So we can use a single array $a$ to store the tree, where
    $$b_{i,j} \rightarrow a_{N_i + j}$$

    Then the children of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
    That is,
    $$a_i = a_{2i} + a_{2i + 1}$$

    This way of maintaining binary trees is very easy to program.
    *Note that the array $a$ is indexed from 1, while $j$ starts at 0 within each row.*
    """

    def __init__(self, capacity, alpha):
        """
        ### Initialize
        """
        # we use a power of 2 for capacity so that the binary segment trees are full,
        # which keeps the indexing simple and easy to debug
        self.capacity = capacity
        # we overwrite the oldest samples once the buffer reaches capacity
        self.next_idx = 0
        # $\alpha$
        self.alpha = alpha

        # maintain binary segment trees to take the sum and find the minimum over a range
        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

        # current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.

        # arrays for the buffer
        self.data = {
            'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=bool)
        }

        # size of the buffer
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        """
        ### Add sample to queue
        """

        idx = self.next_idx

        # store in the queue
        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done

        # increment the head of the queue and calculate the size
        self.next_idx = (idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

        # $p_i^\alpha$, new samples get `max_priority`
        priority_alpha = self.max_priority ** self.alpha
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        """
        #### Set priority in binary segment tree for minimum
        """

        # leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha

        # update the tree by traversing along the ancestors
        while idx >= 2:
            # parent node
            idx //= 2
            # the parent keeps the minimum of its two children
            self.priority_min[idx] = min(self.priority_min[2 * idx],
                                         self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        """
        #### Set priority in binary segment tree for sum
        """

        # leaf of the binary tree
        idx += self.capacity
        self.priority_sum[idx] = priority

        # update the tree by traversing along the ancestors
        while idx >= 2:
            # parent node
            idx //= 2
            # the parent keeps the sum of its two children
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
        """
        #### $\sum_k p_k^\alpha$
        """
        # the root node keeps the sum of all values
        return self.priority_sum[1]

    def _min(self):
        """
        #### $\min_k p_k^\alpha$
        """
        # the root node keeps the minimum of all values
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        """
        #### Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$
        """

        # start from the root
        idx = 1
        while idx < self.capacity:
            # if the sum of the left branch is higher than the required sum
            if self.priority_sum[idx * 2] > prefix_sum:
                # go to the left branch of the tree
                idx = 2 * idx
            else:
                # otherwise go to the right branch and subtract the sum of the
                # left branch from the required sum
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        # convert the leaf index back to a buffer index
        return idx - self.capacity

    def sample(self, batch_size, beta):
        """
        ### Sample from buffer
        """

        # initialize samples
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }

        # get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx

        # $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
        prob_min = self._min() / self._sum()
        # $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
            # $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
            prob = self.priority_sum[idx + self.capacity] / self._sum()
            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
            weight = (prob * self.size) ** (-beta)
            # normalize by $\frac{1}{\max_i w_i}$,
            # which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight

        # get sample data
        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        """
        ### Update priorities
        """
        for idx, priority in zip(indexes, priorities):
            # keep track of the current maximum priority
            self.max_priority = max(self.max_priority, priority)

            # $p_i^\alpha$
            priority_alpha = priority ** self.alpha
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        """
        ### Is the buffer full

        We only start sampling after the buffer is full.
        """
        return self.capacity == self.size

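The following standalone sketch (not part of `replay_buffer.py`) mirrors the sum segment tree logic of `_set_priority_sum` and `find_prefix_sum_idx` above, using four toy priorities chosen only for illustration.

# Standalone sketch of the sum segment tree used by ReplayBuffer.
# Capacity must be a power of two; the priorities are arbitrary toy values.
capacity = 4
priorities = [3.0, 1.0, 4.0, 2.0]             # p_i^alpha for the 4 stored transitions

# Build the tree: leaves at [capacity, 2 * capacity), each parent holds the sum of its children.
tree = [0.0] * (2 * capacity)
for i, p in enumerate(priorities):
    tree[capacity + i] = p
for i in range(capacity - 1, 0, -1):
    tree[i] = tree[2 * i] + tree[2 * i + 1]

assert tree[1] == sum(priorities)             # the root holds the total sum


def find_prefix_sum_idx(prefix_sum: float) -> int:
    """Largest leaf index i such that the sum of priorities up to i is <= prefix_sum."""
    idx = 1
    while idx < capacity:
        if tree[2 * idx] > prefix_sum:        # the answer is in the left subtree
            idx = 2 * idx
        else:                                 # skip the left subtree and subtract its sum
            prefix_sum -= tree[2 * idx]
            idx = 2 * idx + 1
    return idx - capacity


# Prefix sums are [3, 4, 8, 10], so a value of 5.0 falls into the third transition (index 2).
print(find_prefix_sum_idx(5.0))               # 2
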
@@ -128,7 +128,7 @@ class Trainer:
         # Value Loss
         self.value_loss = ClippedValueFunctionLoss()

-    def sample(self) -> (Dict[str, np.ndarray], List):
+    def sample(self) -> Dict[str, torch.Tensor]:
         """
         ### Sample data with current policy
         """