📚 annotations

Varuna Jayasiri
2020-10-25 09:04:47 +05:30
parent 2ddfe1b252
commit b492e7c7ca
4 changed files with 93 additions and 58 deletions


@@ -1,9 +1,8 @@
"""
\(
\def\orange#1{{\color{orange}{#1}}}
\def\blue#1{{\color{cyan}{#1}}}
\def\green#1{{\color{yellowgreen}{#1}}}
\)
# DQN Experiment with Atari Breakout
This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym.
It runs the [game environments on multiple processes](../game.html) to sample efficiently.
"""
import numpy as np
@@ -16,6 +15,7 @@ from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
# Select device
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
@@ -29,17 +29,10 @@ def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
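The body of `obs_to_torch` falls outside this hunk. Presumably it converts the `uint8` observation frames to float tensors on `device`; a minimal sketch of such a helper (the name and the scaling are assumptions, not the file's actual body):

def obs_to_torch_sketch(obs: np.ndarray) -> torch.Tensor:
    # Scale uint8 frames from [0, 255] to floats in [0, 1] on the training device
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.0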
class Trainer:
"""
## Trainer
"""
def __init__(self):
"""
### Initialize
"""
# #### Configurations
# number of workers
@@ -54,7 +47,7 @@ class Trainer:
# size of mini batch for training
self.mini_batch_size = 32
# exploration as a function of updates
self.exploration_coefficient = Piecewise(
[
(0, 1.0),
@@ -65,20 +58,21 @@ class Trainer:
# update target network every 250 updates
self.update_target_model = 250
# $\beta$ for replay buffer as a function of updates
self.prioritized_replay_beta = Piecewise(
[
(0, 0.4),
(self.updates, 1)
], outside_value=1)
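`Piecewise` is used as a schedule that maps the update index to a value. A minimal sketch of such a piecewise-linear schedule, assuming linear interpolation between the `(step, value)` breakpoints and a constant `outside_value` beyond them (hypothetical `PiecewiseSketch`, not the library class):

class PiecewiseSketch:
    def __init__(self, points, outside_value):
        # `points` is a list of (step, value) pairs, sorted by step
        self.points = points
        self.outside_value = outside_value

    def __call__(self, step):
        # Linearly interpolate between the two breakpoints surrounding `step`
        for (x0, y0), (x1, y1) in zip(self.points[:-1], self.points[1:]):
            if x0 <= step < x1:
                return y0 + (y1 - y0) * (step - x0) / (x1 - x0)
        # Constant value outside the breakpoint range
        return self.outside_value

beta_schedule = PiecewiseSketch([(0, 0.4), (100, 1.0)], outside_value=1.0)
assert abs(beta_schedule(50) - 0.7) < 1e-9 and beta_schedule(200) == 1.0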
# replay buffer with $\alpha = 0.6$
self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
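With $\alpha = 0.6$, a transition with priority $p_i$ is sampled with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, and $\beta$ controls the importance-sampling correction $w_i = \big(N \cdot P(i)\big)^{-\beta}$, normalized by the largest weight. A small numpy illustration of these formulas (not the `ReplayBuffer` implementation):

priorities = np.array([2.0, 1.0, 0.5, 0.1])
alpha, beta, n = 0.6, 0.4, len(priorities)
# Sampling probabilities: P(i) = p_i^alpha / sum_k p_k^alpha
probs = priorities ** alpha / np.sum(priorities ** alpha)
# Importance-sampling weights, normalized so the largest weight is 1
weights = (n * probs) ** -beta
weights /= weights.max()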
# Model for sampling and training
self.model = Model().to(device)
# target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
self.target_model = Model().to(device)
# create workers
self.workers = [Worker(47 + i) for i in range(self.n_workers)]
@@ -89,6 +83,7 @@ class Trainer:
for i, worker in enumerate(self.workers):
self.obs[i] = worker.child.recv()
# loss function
self.loss_func = QFuncLoss(0.99)
# optimizer
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
@@ -99,44 +94,48 @@ class Trainer:
When sampling actions we use an $\epsilon$-greedy strategy, where we
take a greedy action with probability $1 - \epsilon$ and
take a random action with probability $\epsilon$.
We refer to $\epsilon$ as `exploration_coefficient`.
"""
# Sampling doesn't need gradients
with torch.no_grad():
# Sample the action with highest Q-value. This is the greedy action.
greedy_action = torch.argmax(q_value, dim=-1)
# Uniformly sample an action
random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
# Whether to choose the greedy action or the random action
is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
# Pick the action based on `is_choose_rand`
return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
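A quick sanity check of the selection logic above on toy Q-values: with `exploration_coefficient = 0` every action is greedy (standalone sketch mirroring the three tensor operations):

q = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
greedy = torch.argmax(q, dim=-1)            # [1, 0]
rand_a = torch.randint(q.shape[-1], greedy.shape)
pick_rand = torch.rand(greedy.shape) < 0.0  # exploration coefficient of zero
assert torch.equal(torch.where(pick_rand, rand_a, greedy), greedy)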
def sample(self, exploration_coefficient: float):
"""### Sample data"""
# This doesn't need gradients
with torch.no_grad():
# Sample `worker_steps`
for t in range(self.worker_steps):
# Get Q-values for the current observation
q_value = self.model(obs_to_torch(self.obs))
# Sample actions
actions = self._sample_action(q_value, exploration_coefficient)
# Run sampled actions on each worker
for w, worker in enumerate(self.workers):
worker.child.send(("step", actions[w]))
# Collect information from each worker
for w, worker in enumerate(self.workers):
# Get results after executing the actions
next_obs, reward, done, info = worker.child.recv()
# Add transition to replay buffer
self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
# Collect episode info, which is available when an episode finishes;
# it includes the total reward and the length of the episode.
# Look at `Game` to see how it works.
# We also add a game frame to it for monitoring.
if info:
tracker.add('reward', info['reward'])
tracker.add('length', info['length'])
@@ -145,16 +144,24 @@ class Trainer:
self.obs[w] = next_obs
def train(self, beta: float):
"""
### Train the model
"""
for _ in range(self.train_epochs):
# Sample from priority replay buffer
samples = self.replay_buffer.sample(self.mini_batch_size, beta)
# Get the predicted Q-value
q_value = self.model(obs_to_torch(samples['obs']))
# Get the Q-values of the next state for [Double Q-learning](index.html).
# Gradients shouldn't propagate for these
with torch.no_grad():
# Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
double_q_value = self.model(obs_to_torch(samples['next_obs']))
# Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
# Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
td_errors, loss = self.loss_func(q_value,
q_value.new_tensor(samples['action']),
double_q_value, target_q_value,
@@ -162,15 +169,18 @@ class Trainer:
q_value.new_tensor(samples['reward']),
q_value.new_tensor(samples['weights']))
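`QFuncLoss` is defined in [the DQN implementation](index.html). Under double Q-learning the TD target should be $r + \gamma \, Q\big(s', \arg\max_{a'} Q(s', a'; \theta_i); \theta_i^{-}\big)$ for non-terminal $s'$; a toy illustration of that target (an assumed reading of the formula, not the library's code):

gamma = 0.99
done = torch.tensor([0.0, 1.0])                    # second transition is terminal
reward = torch.tensor([1.0, 0.5])
double_q = torch.tensor([[0.2, 0.7], [0.4, 0.1]])  # Q(s'; theta_i), online network
target_q = torch.tensor([[0.3, 0.6], [0.5, 0.2]])  # Q(s'; theta_i^-), target network
# The action is chosen by the online network but evaluated by the target network
best_action = torch.argmax(double_q, dim=-1)
td_target = reward + gamma * (1 - done) * target_q.gather(-1, best_action.unsqueeze(-1)).squeeze(-1)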
# Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$
new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
# Update replay buffer priorities
self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
# Zero out the previously calculated gradients
self.optimizer.zero_grad()
# Calculate gradients
loss.backward()
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
# Update parameters based on gradients
self.optimizer.step()
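`clip_grad_norm_` used above rescales the gradients in place when their global L2 norm exceeds `max_norm`; a quick standalone check of that behavior (not part of the trainer):

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.ones(4)  # gradient norm is 2.0
torch.nn.utils.clip_grad_norm_([p], max_norm=0.5)
assert torch.allclose(p.grad.norm(), torch.tensor(0.5))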
def run_training_loop(self):
@@ -178,33 +188,36 @@ class Trainer:
### Run training loop
"""
# Last 100 episode information
tracker.set_queue('reward', 100, True)
tracker.set_queue('length', 100, True)
# Copy to target network initially
self.target_model.load_state_dict(self.model.state_dict())
for update in monit.loop(self.updates):
# $\epsilon$, exploration fraction
exploration = self.exploration_coefficient(update)
tracker.add('exploration', exploration)
# $\beta$ for prioritized replay
beta = self.prioritized_replay_beta(update)
tracker.add('beta', beta)
# Sample with current policy
self.sample(exploration)
# Start training after the buffer is full
if self.replay_buffer.is_full():
# Train the model
self.train(beta)
# Periodically update target network
if update % self.update_target_model == 0:
self.target_model.load_state_dict(self.model.state_dict())
# Save tracked indicators.
tracker.save()
# Add a new line to the screen periodically
if (update + 1) % 1_000 == 0:
logger.log()
@@ -217,10 +230,18 @@ class Trainer:
worker.child.send(("close", None))
def main():
# Create the experiment
experiment.create(name='dqn')
# Initialize the trainer
m = Trainer()
# Run and monitor the experiment
with experiment.start():
m.run_training_loop()
# Stop the workers
m.destroy()
# ## Run it
if __name__ == "__main__":
main()