This commit is contained in:
Varuna Jayasiri
2020-10-24 13:34:26 +05:30
parent b4d35f871e
commit 7cc9f300e9
4 changed files with 672 additions and 1 deletion


@@ -0,0 +1,50 @@
"""
This is a Deep Q Learning implementation with:
* Double Q Network
* Dueling Network
* Prioritized Replay
"""
from typing import Tuple
import torch
from labml import tracker
from labml_helpers.module import Module
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
class QFuncLoss(Module):
def __init__(self, gamma: float):
super().__init__()
self.gamma = gamma
def __call__(self, q: torch.Tensor,
action: torch.Tensor,
double_q: torch.Tensor,
target_q: torch.Tensor,
done: torch.Tensor,
reward: torch.Tensor,
weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
tracker.add('q_sampled_action', q_sampled_action)
with torch.no_grad():
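# pick the best next action from the online network's estimates,
# $\mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i)$ (double Q-learning)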
best_next_action = torch.argmax(double_q, -1)
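# evaluate that action with the target network, $Q(s', \mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i); \theta_i^{-})$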
best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
best_next_q_value *= (1 - done)
q_update = reward + self.gamma * best_next_q_value
tracker.add('q_update', q_update)
td_error = q_sampled_action - q_update
tracker.add('td_error', td_error)
# Huber loss
losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
loss = torch.mean(weights * losses)
tracker.add('loss', loss)
return td_error, loss
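A rough usage sketch of this loss (illustrative, not part of this commit); the batch size of 32 and the 4 actions match the training code below, and it assumes `labml`'s tracker can be called outside a running experiment:
loss_func = QFuncLoss(gamma=0.99)
q = torch.randn(32, 4)               # Q(s, a; theta) for the sampled states
double_q = torch.randn(32, 4)        # Q(s', a; theta) from the online network
target_q = torch.randn(32, 4)        # Q(s', a; theta^-) from the target network
action = torch.randint(0, 4, (32,))  # actions taken in the sampled transitions
done = torch.zeros(32)               # 1.0 where the episode terminated
reward = torch.randn(32)
weights = torch.ones(32)             # importance-sampling weights from the replay buffer
td_error, loss = loss_func(q, action, double_q, target_q, done, reward, weights)
assert td_error.shape == (32,) and loss.dim() == 0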


@@ -0,0 +1,385 @@
"""
\(
\def\hl1#1{{\color{orange}{#1}}}
\def\blue#1{{\color{cyan}{#1}}}
\def\green#1{{\color{yellowgreen}{#1}}}
\)
"""
import numpy as np
import torch
from torch import nn
from labml import tracker, experiment, logger, monit
from labml_helpers.module import Module
from labml_helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
class Model(Module):
"""
## <a name="model"></a>Neural Network Model for $Q$ Values
#### Dueling Network ⚔️
We are using a [dueling network](https://arxiv.org/abs/1511.06581)
to calculate Q-values.
The intuition behind the dueling network architecture is that in most states
the action doesn't matter,
while in some states the action is significant. The dueling network allows
this to be represented very well.
\begin{align}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
\\
\mathop{\mathbb{E}}_{a \sim \pi(s)}
\Big[
A^\pi(s, a)
\Big]
&= 0
\end{align}
So we create two networks for $V$ and $A$ and get $Q$ from them.
$$
Q(s, a) = V(s) +
\Big(
A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
\Big)
$$
We share the initial layers of the $V$ and $A$ networks.
"""
def __init__(self):
"""
### Initialize
We keep two copies of this network, with separate parameters:
one for the training network and one for the target network.
"""
super().__init__()
self.conv = nn.Sequential(
# The first convolution layer takes a
# 84x84 frame and produces a 20x20 frame
nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
# The second convolution layer takes a
# 20x20 frame and produces a 9x9 frame
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
# The third convolution layer takes a
# 9x9 frame and produces a 7x7 frame
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
)
# A fully connected layer takes the flattened
# frame from third convolution layer, and outputs
# 512 features
self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
self.state_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=1),
)
self.action_score = nn.Sequential(
nn.Linear(in_features=512, out_features=256),
nn.ReLU(),
nn.Linear(in_features=256, out_features=4),
)
#
self.activation = nn.ReLU()
def __call__(self, obs: torch.Tensor):
h = self.conv(obs)
h = h.reshape((-1, 7 * 7 * 64))
h = self.activation(self.lin(h))
action_score = self.action_score(h)
state_score = self.state_score(h)
# $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True)
q = state_score + action_score_centered
return q
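A quick sanity check of the shapes (illustrative, not part of the file): the convolution stack maps an $84 \times 84$ observation down to $7 \times 7$, and the dueling heads combine into one Q-value per action.
model = Model()
dummy_obs = torch.zeros(8, 4, 84, 84)   # a batch of 8 stacked-frame observations
q_values = model(dummy_obs)
assert q_values.shape == (8, 4)         # one Q-value for each of the 4 actions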
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
"""Scale observations from `[0, 255]` to `[0, 1]`"""
return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
class Main(object):
"""
## <a name="main"></a>Main class
This class runs the training loop.
It initializes the models and the optimizer, handles logging and monitoring,
and runs workers as multiple processes.
"""
def __init__(self):
"""
### Initialize
"""
# #### Configurations
# number of workers
self.n_workers = 8
# steps sampled on each update
self.worker_steps = 4
# number of mini-batch training iterations on each update
self.train_epochs = 8
# number of updates
self.updates = 1_000_000
# size of mini batch for training
self.mini_batch_size = 32
# exploration as a function of time step
self.exploration_coefficient = Piecewise(
[
(0, 1.0),
(25_000, 0.1),
(self.updates / 2, 0.01)
], outside_value=0.01)
# update the target network every 250 updates
self.update_target_model = 250
# $\beta$ for replay buffer as a function of time steps
self.prioritized_replay_beta = Piecewise(
[
(0, 0.4),
(self.updates, 1)
], outside_value=1)
# replay buffer
self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
self.model = Model().to(device)
self.target_model = Model().to(device)
# create workers
self.workers = [Worker(47 + i) for i in range(self.n_workers)]
# initialize tensors for the last observation of each worker
self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
for worker in self.workers:
worker.child.send(("reset", None))
for i, worker in enumerate(self.workers):
self.obs[i] = worker.child.recv()
self.loss_func = QFuncLoss(0.99)
# optimizer
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
"""
#### $\epsilon$-greedy Sampling
When sampling actions we use an $\epsilon$-greedy strategy, where we
take a greedy action with probability $1 - \epsilon$ and
a random action with probability $\epsilon$.
We refer to $\epsilon$ as *exploration*.
"""
with torch.no_grad():
greedy_action = torch.argmax(q_value, dim=-1)
random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
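As a standalone numeric check of the rule above (illustrative, not from the file), with $\epsilon = 0.1$ the greedy action should be chosen about $1 - \epsilon + \epsilon / 4 = 92.5\%$ of the time, since a random pick still lands on the greedy action a quarter of the time:
eps = 0.1
q_values = torch.zeros(10_000, 4)
q_values[:, 2] = 1.0                           # action 2 is always the greedy choice
greedy = torch.argmax(q_values, dim=-1)
random_action = torch.randint(4, greedy.shape)
is_random = torch.rand(greedy.shape) < eps
actions = torch.where(is_random, random_action, greedy)
print((actions == 2).float().mean())           # ~0.925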
def sample(self, exploration_coefficient: float):
"""### Sample data"""
with torch.no_grad():
# take `worker_steps` steps on each worker
for t in range(self.worker_steps):
# sample actions
q_value = self.model(obs_to_torch(self.obs))
actions = self._sample_action(q_value, exploration_coefficient)
# run sampled actions on each worker
for w, worker in enumerate(self.workers):
worker.child.send(("step", actions[w]))
# collect information from each worker
for w, worker in enumerate(self.workers):
# get results after executing the actions
next_obs, reward, done, info = worker.child.recv()
# add transition to replay buffer
self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
# collect episode info, which is available when an episode finishes;
# this includes the total reward and the length of the episode -
# look at `Game` to see how it works
if info:
tracker.add('reward', info['reward'])
tracker.add('length', info['length'])
# update current observation
self.obs[w] = next_obs
def train(self, beta: float):
"""
### Train the model
We want to find the optimal action-value function.
\begin{align}
Q^*(s,a) &= \max_\pi \mathbb{E} \Big[
r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... | s_t = s, a_t = a, \pi
\Big]
\\
Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
r + \gamma \max_{a'} Q^* (s', a') | s, a
\Big]
\end{align}
#### Target network 🎯
In order to improve stability we use experience replay, which randomly samples
from previous experience $U(D)$. We also use a Q network
with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
$\hl1{\theta_i^{-}}$ is updated periodically.
This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).
So the loss function is,
$$
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\bigg[
\Big(
r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
\Big) ^ 2
\bigg]
$$
#### Double $Q$-Learning
The max operator in the above calculation uses the same network both to
select the best action and to evaluate its value.
That is,
$$
\max_{a'} Q(s', a'; \theta) = \blue{Q}
\Big(
s', \mathop{\operatorname{argmax}}_{a'}
\blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
\Big)
$$
We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
the $\operatorname{argmax}$ is taken from $\theta_i$ and
the value is taken from $\theta_i^{-}$.
And the loss function becomes,
\begin{align}
\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
\Bigg[
\bigg(
&r + \gamma \blue{Q}
\Big(
s',
\mathop{\operatorname{argmax}}_{a'}
\green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
\Big)
\\
- &Q(s,a;\theta_i)
\bigg) ^ 2
\Bigg]
\end{align}
"""
for _ in range(self.train_epochs):
# sample from the prioritized replay buffer
samples = self.replay_buffer.sample(self.mini_batch_size, beta)
# train network
q_value = self.model(obs_to_torch(samples['obs']))
with torch.no_grad():
double_q_value = self.model(obs_to_torch(samples['next_obs']))
target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
td_errors, loss = self.loss_func(q_value,
q_value.new_tensor(samples['action']),
double_q_value, target_q_value,
q_value.new_tensor(samples['done']),
q_value.new_tensor(samples['reward']),
q_value.new_tensor(samples['weights']))
# $p_i = |\delta_i| + \epsilon$
new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
# update replay buffer
self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
# compute gradients
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
self.optimizer.step()
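A toy illustration (not from this commit) of the overestimation that double Q-learning avoids: with noisy value estimates of equally good actions, taking $\max_a Q$ with a single set of estimates is biased upwards, while evaluating the $\operatorname{argmax}$ with an independent set of estimates is not.
noise = 0.5
online_q = noise * torch.randn(10_000, 4)      # noisy estimates; the true value of every action is 0
target_q_est = noise * torch.randn(10_000, 4)  # independent noisy estimates
single = online_q.max(dim=-1).values.mean()    # clearly positive
double = target_q_est.gather(-1, online_q.argmax(dim=-1, keepdim=True)).mean()  # close to zero
print(float(single), float(double))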
def run_training_loop(self):
"""
### Run training loop
"""
# copy to target network initially
self.target_model.load_state_dict(self.model.state_dict())
# keep running stats of the last 100 episodes
tracker.set_queue('reward', 100, True)
tracker.set_queue('length', 100, True)
for update in monit.loop(self.updates):
# $\epsilon$, exploration fraction
exploration = self.exploration_coefficient(update)
tracker.add('exploration', exploration)
# $\beta$ for priority replay
beta = self.prioritized_replay_beta(update)
tracker.add('beta', beta)
# sample with current policy
self.sample(exploration)
if self.replay_buffer.is_full():
# train the model
self.train(beta)
# periodically update target network
if update % self.update_target_model == 0:
self.target_model.load_state_dict(self.model.state_dict())
tracker.save()
if (update + 1) % 1_000 == 0:
logger.log()
def destroy(self):
"""
### Destroy
Stop the workers
"""
for worker in self.workers:
worker.child.send(("close", None))
# ## Run it
if __name__ == "__main__":
experiment.create(name='dqn')
m = Main()
with experiment.start():
m.run_training_loop()
m.destroy()


@@ -0,0 +1,236 @@
import numpy as np
import random
class ReplayBuffer:
"""
## Buffer for Prioritized Experience Replay
[Prioritized experience replay](https://arxiv.org/abs/1511.05952)
samples important transitions more frequently.
The transitions are prioritized by the Temporal Difference error.
We sample transition $i$ with probability,
$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
where $\alpha$ is a hyper-parameter that determines how much
prioritization is used, with $\alpha = 0$ corresponding to the uniform case.
We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where
$\delta_i$ is the temporal difference for transition $i$.
We correct the bias introduced by prioritized replay with
importance-sampling (IS) weights
$$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
which fully compensate for the bias when $\beta = 1$.
We normalize the weights by $1/\max_i w_i$ for stability.
Unbiased updates matter most near convergence, at the end of training,
so we anneal $\beta$ towards $1$ as training progresses.
### Binary Segment Trees
We use binary segment trees to efficiently calculate
$\sum_k^i p_k^\alpha$, the cumulative probability,
which is needed to sample.
We also use a binary segment tree to find $\min p_i^\alpha$,
which is needed for $1/\max_i w_i$.
We can also use a min-heap for this.
This is how a binary segment tree works for sum;
it is similar for minimum.
Let $x_i$ be the list of $N$ values we want to represent.
Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row
in the binary tree.
That is, the two children of node $b_{i,j}$ are $b_{i+1,2j - 1}$ and $b_{i+1,2j}$.
The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
will have values of $x$.
Every node keeps the sum of the two child nodes.
So the root node keeps the sum of the entire array of values.
The two children of the root node keep
the sum of the first half of the array and
the sum of the second half of the array, and so on.
$$b_{i,j} = \sum_{k = (j -1) * 2^{D - i} + 1}^{j * 2^{D - i}} x_k$$
Number of nodes in row $i$,
$$N_i = \left\lceil{\frac{N}{2^{D - i}}} \right\rceil$$
When $N$ is a power of two, this is one more than the number of nodes in all rows above $i$.
So we can use a single array $a$ to store the tree, where,
$$b_{i,j} \rightarrow a_{N_i + j - 1}$$
Then child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
That is,
$$a_i = a_{2i} + a_{2i + 1}$$
This way of maintaining binary trees is very easy to program.
*Note that we are indexing from 1*.
"""
def __init__(self, capacity, alpha):
"""
### Initialize
"""
# we use a power of 2 for capacity so that the segment trees are full binary trees
self.capacity = capacity
# the buffer is a circular queue; the oldest samples are overwritten once it reaches capacity
self.next_idx = 0
# $\alpha$
self.alpha = alpha
# maintain binary segment trees to take the sum and find the minimum over a range
self.priority_sum = [0 for _ in range(2 * self.capacity)]
self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
# current max priority, $p$, to be assigned to new transitions
self.max_priority = 1.
# arrays for buffer
self.data = {
'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
'action': np.zeros(shape=capacity, dtype=np.int32),
'reward': np.zeros(shape=capacity, dtype=np.float32),
'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
'done': np.zeros(shape=capacity, dtype=bool)
}
# size of the buffer
self.size = 0
def add(self, obs, action, reward, next_obs, done):
"""
### Add sample to queue
"""
idx = self.next_idx
# store in the queue
self.data['obs'][idx] = obs
self.data['action'][idx] = action
self.data['reward'][idx] = reward
self.data['next_obs'][idx] = next_obs
self.data['done'][idx] = done
# increment head of the queue and calculate the size
self.next_idx = (idx + 1) % self.capacity
self.size = min(self.capacity, self.size + 1)
# $p_i^\alpha$, new samples get `max_priority`
priority_alpha = self.max_priority ** self.alpha
self._set_priority_min(idx, priority_alpha)
self._set_priority_sum(idx, priority_alpha)
def _set_priority_min(self, idx, priority_alpha):
"""
#### Set priority in binary segment tree for minimum
"""
# leaf of the binary tree
idx += self.capacity
self.priority_min[idx] = priority_alpha
# update tree, by traversing along ancestors
while idx >= 2:
idx //= 2
self.priority_min[idx] = min(self.priority_min[2 * idx],
self.priority_min[2 * idx + 1])
def _set_priority_sum(self, idx, priority):
"""
#### Set priority in binary segment tree for sum
"""
# leaf of the binary tree
idx += self.capacity
self.priority_sum[idx] = priority
# update tree, by traversing along ancestors
while idx >= 2:
idx //= 2
self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
def _sum(self):
"""
#### $\sum_k p_k^\alpha$
"""
return self.priority_sum[1]
def _min(self):
"""
#### $\min_k p_k^\alpha$
"""
return self.priority_min[1]
def find_prefix_sum_idx(self, prefix_sum):
"""
#### Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$
"""
# start from the root
idx = 1
while idx < self.capacity:
# if the sum of the left branch is higher than the required sum
if self.priority_sum[idx * 2] > prefix_sum:
# go to the left branch of the tree
idx = 2 * idx
else:
# otherwise go to the right branch and subtract the sum of the left
# branch from the required sum
prefix_sum -= self.priority_sum[idx * 2]
idx = 2 * idx + 1
return idx - self.capacity
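# Worked example (illustrative): with capacity 4 and priorities [1, 2, 3, 4] the
# sum tree array is a = [_, 10, 3, 7, 1, 2, 3, 4], with the root at a[1].
# For `prefix_sum = 5.5`: at the root the left child holds 3 <= 5.5, so we go right
# with 5.5 - 3 = 2.5; at a[3] the left child a[6] holds 3 > 2.5, so we go left and
# stop at leaf a[6], i.e. buffer index 2 (the cumulative sums are [1, 3, 6, 10]).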
def sample(self, batch_size, beta):
"""
### Sample from buffer
"""
samples = {
'weights': np.zeros(shape=batch_size, dtype=np.float32),
'indexes': np.zeros(shape=batch_size, dtype=np.int32)
}
# get samples
for i in range(batch_size):
p = random.random() * self._sum()
idx = self.find_prefix_sum_idx(p)
samples['indexes'][i] = idx
# $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
prob_min = self._min() / self._sum()
# $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
max_weight = (prob_min * self.size) ** (-beta)
for i in range(batch_size):
idx = samples['indexes'][i]
# $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
prob = self.priority_sum[idx + self.capacity] / self._sum()
# $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
weight = (prob * self.size) ** (-beta)
# normalize by $\frac{1}{\max_i w_i}$,
# which also cancels off the $\frac{1}{N}$ term
samples['weights'][i] = weight / max_weight
# get samples data
for k, v in self.data.items():
samples[k] = v[samples['indexes']]
return samples
def update_priorities(self, indexes, priorities):
"""
### Update priorities
"""
for idx, priority in zip(indexes, priorities):
self.max_priority = max(self.max_priority, priority)
# $p_i^\alpha$
priority_alpha = priority ** self.alpha
self._set_priority_min(idx, priority_alpha)
self._set_priority_sum(idx, priority_alpha)
def is_full(self):
"""
### Is the buffer full
We only start sampling after the buffer is full.
"""
return self.capacity == self.size
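An illustrative round trip with the class above (not part of this commit): fill a tiny buffer, sample with importance-sampling weights, then feed back new priorities after a training step.
buffer = ReplayBuffer(capacity=4, alpha=0.6)
frame = np.zeros((4, 84, 84), dtype=np.uint8)
for t in range(4):
    buffer.add(frame, action=0, reward=float(t), next_obs=frame, done=False)
assert buffer.is_full()
samples = buffer.sample(batch_size=2, beta=0.4)
# new samples all share `max_priority`, so the normalized weights are 1
print(samples['indexes'], samples['weights'])
# after a training step the absolute TD errors become the new priorities
buffer.update_priorities(samples['indexes'], priorities=np.array([0.5, 2.0]))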


@@ -128,7 +128,7 @@ class Trainer:
# Value Loss
self.value_loss = ClippedValueFunctionLoss()
-    def sample(self) -> (Dict[str, np.ndarray], List):
+    def sample(self) -> Dict[str, torch.Tensor]:
"""
### Sample data with current policy
"""