Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-11-01 03:43:09 +08:00

Commit: 📚 annotations
@@ -3,7 +3,8 @@
 
 * [Proximal Policy Optimization](ppo)
 * [This is an experiment](ppo/experiment.html) that runs a PPO agent on Atari Breakout.
 * [Generalized advantage estimation](ppo/gae.html)
+* [Deep Q Networks
 
 [This is the implementation for OpenAI game wrapper](game.html) that uses `multiprocessing`.
 """
@@ -1,13 +1,12 @@
 """
 # Deep Q Networks
 
-This is a Deep Q Learning implementation that uses:
+This is an implementation of paper
+[Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602)
+along with [Dueling Network](model.html), [Prioritized Replay](replay_buffer.html)
+and Double Q Network.
 
-* [Dueling Network](model.html)
-* [Prioritized Replay](replay_buffer.html)
-* Double Q Network
-
-Here's the [experiment](experiment.html) and [model](model.html).
+Here are the [experiment](experiment.html) and [model](model.html) implementations.
 
 \(
 \def\green#1{{\color{yellowgreen}{#1}}}
@@ -1,9 +1,8 @@
 """
-\(
-\def\hl1#1{{\color{orange}{#1}}}
-\def\blue#1{{\color{cyan}{#1}}}
-\def\green#1{{\color{yellowgreen}{#1}}}
-\)
+# DQN Experiment with Atari Breakout
+
+This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym.
+It runs the [game environments on multiple processes](../game.html) to sample efficiently.
 """
 
 import numpy as np
@@ -16,6 +15,7 @@ from labml_nn.rl.dqn.model import Model
 from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
 from labml_nn.rl.game import Worker
 
+# Select device
 if torch.cuda.is_available():
     device = torch.device("cuda:0")
 else:
@@ -29,17 +29,10 @@ def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
 
 class Trainer:
     """
-    ## <a name="main"></a>Main class
-    This class runs the training loop.
-    It initializes TensorFlow, handles logging and monitoring,
-    and runs workers as multiple processes.
+    ## Trainer
     """
 
     def __init__(self):
-        """
-        ### Initialize
-        """
-
         # #### Configurations
 
         # number of workers
@@ -54,7 +47,7 @@ class Trainer:
         # size of mini batch for training
         self.mini_batch_size = 32
 
-        # exploration as a function of time step
+        # exploration as a function of updates
         self.exploration_coefficient = Piecewise(
             [
                 (0, 1.0),
@@ -65,20 +58,21 @@ class Trainer:
         # update target network every 250 updates
         self.update_target_model = 250
 
-        # $\beta$ for replay buffer as a function of time steps
+        # $\beta$ for replay buffer as a function of updates
         self.prioritized_replay_beta = Piecewise(
             [
                 (0, 0.4),
                 (self.updates, 1)
             ], outside_value=1)
 
-        # replay buffer
+        # replay buffer with $\alpha = 0.6$
         self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
 
+        # Model for sampling and training
         self.model = Model().to(device)
+        # target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
         self.target_model = Model().to(device)
 
-        # last observation for each worker
         # create workers
         self.workers = [Worker(47 + i) for i in range(self.n_workers)]
 
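The `Piecewise` schedules above anneal the exploration coefficient and the prioritized-replay $\beta$ over updates. As a rough illustration of what such a piecewise-linear schedule does, here is a hypothetical stand-in (not the repo's `Piecewise` helper, and with made-up endpoint values):

```python
from typing import List, Tuple


class PiecewiseLinear:
    """Linearly interpolate between (step, value) endpoints; hypothetical stand-in."""

    def __init__(self, endpoints: List[Tuple[int, float]], outside_value: float):
        self.endpoints = endpoints
        self.outside_value = outside_value

    def __call__(self, step: int) -> float:
        # Interpolate inside each consecutive pair of endpoints
        for (x1, y1), (x2, y2) in zip(self.endpoints[:-1], self.endpoints[1:]):
            if x1 <= step < x2:
                t = (step - x1) / (x2 - x1)
                return y1 + t * (y2 - y1)
        # Outside the defined range, fall back to a constant
        return self.outside_value


# e.g. exploration decays from 1.0 to 0.1 over the first 25k updates, then stays at 0.1
exploration = PiecewiseLinear([(0, 1.0), (25_000, 0.1)], outside_value=0.1)
print(exploration(0), exploration(12_500), exploration(50_000))  # 1.0 0.55 0.1
```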
@@ -89,6 +83,7 @@ class Trainer:
         for i, worker in enumerate(self.workers):
             self.obs[i] = worker.child.recv()
 
+        # loss function
         self.loss_func = QFuncLoss(0.99)
         # optimizer
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
@@ -99,44 +94,48 @@ class Trainer:
         When sampling actions we use a $\epsilon$-greedy strategy, where we
         take a greedy action with probability $1 - \epsilon$ and
         take a random action with probability $\epsilon$.
-        We refer to $\epsilon$ as *exploration*.
+        We refer to $\epsilon$ as `exploration_coefficient`.
         """
 
+        # Sampling doesn't need gradients
         with torch.no_grad():
+            # Sample the action with the highest Q-value. This is the greedy action.
             greedy_action = torch.argmax(q_value, dim=-1)
+            # Uniformly sample an action
             random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
+            # Whether to choose the greedy action or the random action
             is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
+            # Pick the action based on `is_choose_rand`
             return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
 
     def sample(self, exploration_coefficient: float):
         """### Sample data"""
 
+        # This doesn't need gradients
         with torch.no_grad():
-            # sample `SAMPLE_STEPS`
+            # Sample `worker_steps`
             for t in range(self.worker_steps):
-                # sample actions
+                # Get Q-values for the current observation
                 q_value = self.model(obs_to_torch(self.obs))
+                # Sample actions
                 actions = self._sample_action(q_value, exploration_coefficient)
 
-                # run sampled actions on each worker
+                # Run sampled actions on each worker
                 for w, worker in enumerate(self.workers):
                     worker.child.send(("step", actions[w]))
 
-                # collect information from each worker
+                # Collect information from each worker
                 for w, worker in enumerate(self.workers):
-                    # get results after executing the actions
+                    # Get results after executing the actions
                     next_obs, reward, done, info = worker.child.recv()
 
-                    # add transition to replay buffer
+                    # Add transition to replay buffer
                     self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
 
                     # update episode information
                     # collect episode info, which is available if an episode finished;
                     # this includes total reward and length of the episode -
                     # look at `Game` to see how it works.
-                    # We also add a game frame to it for monitoring.
                     if info:
                         tracker.add('reward', info['reward'])
                         tracker.add('length', info['length'])
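The $\epsilon$-greedy sampling documented in this hunk picks the argmax action with probability $1 - \epsilon$ and a uniformly random action otherwise. A self-contained sketch of the same idea on a dummy batch of Q-values (an illustration, not the repo's `_sample_action`):

```python
import torch


def epsilon_greedy(q_values: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Pick argmax actions, replaced by uniform random actions with probability epsilon."""
    # Greedy action per row: index of the highest Q-value
    greedy = torch.argmax(q_values, dim=-1)
    # A uniformly random action per row
    random = torch.randint(q_values.shape[-1], greedy.shape, device=q_values.device)
    # Bernoulli(epsilon) mask decides where to explore
    explore = torch.rand(greedy.shape, device=q_values.device) < epsilon
    return torch.where(explore, random, greedy)


q = torch.tensor([[0.1, 0.9, 0.2],
                  [0.5, 0.4, 0.3]])
print(epsilon_greedy(q, epsilon=0.1))  # usually tensor([1, 0]), occasionally random
```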
@@ -145,16 +144,24 @@ class Trainer:
                     self.obs[w] = next_obs
 
     def train(self, beta: float):
+        """
+        ### Train the model
+        """
         for _ in range(self.train_epochs):
-            # sample from priority replay buffer
+            # Sample from priority replay buffer
             samples = self.replay_buffer.sample(self.mini_batch_size, beta)
-            # train network
+            # Get the predicted Q-value
             q_value = self.model(obs_to_torch(samples['obs']))
 
+            # Get the Q-values of the next state for [Double Q-learning](index.html).
+            # Gradients shouldn't propagate for these
             with torch.no_grad():
+                # Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
                 double_q_value = self.model(obs_to_torch(samples['next_obs']))
+                # Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
                 target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
 
+            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
             td_errors, loss = self.loss_func(q_value,
                                              q_value.new_tensor(samples['action']),
                                              double_q_value, target_q_value,
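The hunk above queries both the online model, $\color{cyan}Q(s';\color{cyan}{\theta_i})$, and the target model, $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$, which is what Double Q-learning needs: the online network chooses the next action and the target network evaluates it. A hedged sketch of forming that TD target (illustrative only; the repo's `QFuncLoss` may differ in details such as weighting and terminal handling):

```python
import torch


def double_dqn_target(reward: torch.Tensor,
                      done: torch.Tensor,
                      double_q_value: torch.Tensor,   # Q(s'; theta_i) from the online model
                      target_q_value: torch.Tensor,   # Q(s'; theta_i^-) from the target model
                      gamma: float = 0.99) -> torch.Tensor:
    # Online network picks the next action: argmax_a Q(s', a; theta_i)
    best_next_action = torch.argmax(double_q_value, dim=-1, keepdim=True)
    # Target network evaluates that action: Q(s', argmax; theta_i^-)
    next_q = target_q_value.gather(-1, best_next_action).squeeze(-1)
    # TD target: r + gamma * next_q, with no bootstrap on terminal transitions
    return reward + gamma * next_q * (1.0 - done)


r = torch.tensor([1.0, 0.0])
d = torch.tensor([0.0, 1.0])
q_online = torch.tensor([[0.2, 0.8], [0.5, 0.1]])
q_target = torch.tensor([[0.3, 0.6], [0.4, 0.2]])
print(double_dqn_target(r, d, q_online, q_target))  # tensor([1.5940, 0.0000])
```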
@@ -162,15 +169,18 @@ class Trainer:
                                              q_value.new_tensor(samples['reward']),
                                              q_value.new_tensor(samples['weights']))
 
-            # $p_i = |\delta_i| + \epsilon$
+            # Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$
             new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
-            # update replay buffer
+            # Update replay buffer priorities
             self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
 
-            # compute gradients
+            # Zero out the previously calculated gradients
             self.optimizer.zero_grad()
+            # Calculate gradients
             loss.backward()
+            # Clip gradients
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
+            # Update parameters based on gradients
             self.optimizer.step()
 
     def run_training_loop(self):
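Priorities are set to $p_i = |\delta_i| + \epsilon$ after each training step. In prioritized replay such priorities typically become sampling probabilities $P(i) \propto p_i^\alpha$ and importance-sampling weights $w_i = (N \cdot P(i))^{-\beta}$; here is a small numeric sketch with made-up TD errors (not the repo's buffer internals):

```python
import numpy as np

# Hypothetical TD errors for a 4-transition buffer
td_errors = np.array([0.5, 0.1, 2.0, 0.0])
alpha, beta, eps = 0.6, 0.4, 1e-6

# Priorities p_i = |delta_i| + eps, as in the hunk above
priorities = np.abs(td_errors) + eps

# Sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha
probs = priorities ** alpha / np.sum(priorities ** alpha)

# Importance-sampling weights w_i = (N * P(i))^(-beta), normalized by the max for stability
weights = (len(priorities) * probs) ** (-beta)
weights /= weights.max()

print(probs.round(3), weights.round(3))
```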
@@ -178,33 +188,36 @@ class Trainer:
         ### Run training loop
         """
 
-        # copy to target network initially
-        self.target_model.load_state_dict(self.model.state_dict())
-
-        # last 100 episode information
+        # Last 100 episode information
         tracker.set_queue('reward', 100, True)
         tracker.set_queue('length', 100, True)
 
+        # Copy to target network initially
+        self.target_model.load_state_dict(self.model.state_dict())
+
         for update in monit.loop(self.updates):
             # $\epsilon$, exploration fraction
             exploration = self.exploration_coefficient(update)
             tracker.add('exploration', exploration)
-            # $\beta$ for priority replay
+            # $\beta$ for prioritized replay
             beta = self.prioritized_replay_beta(update)
             tracker.add('beta', beta)
 
-            # sample with current policy
+            # Sample with current policy
             self.sample(exploration)
 
+            # Start training after the buffer is full
             if self.replay_buffer.is_full():
-                # train the model
+                # Train the model
                 self.train(beta)
 
-            # periodically update target network
+            # Periodically update target network
             if update % self.update_target_model == 0:
                 self.target_model.load_state_dict(self.model.state_dict())
 
+            # Save tracked indicators.
             tracker.save()
+            # Add a new line to the screen periodically
             if (update + 1) % 1_000 == 0:
                 logger.log()
 
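The loop above copies the online network into the target network once before training and then again every `update_target_model` updates. A minimal sketch of that hard update on stand-in `nn.Linear` models (hypothetical models, not the repo's dueling `Model`):

```python
import torch
import torch.nn as nn

# Stand-in online and target networks
model = nn.Linear(4, 2)
target_model = nn.Linear(4, 2)

update_target_every = 250

for update in range(1_000):
    # ... sampling and training would happen here ...

    # Periodically copy the online parameters into the target network
    if update % update_target_every == 0:
        target_model.load_state_dict(model.state_dict())

# After the copy (and with no training in this sketch) the two networks agree exactly
assert all(torch.equal(p, q) for p, q in zip(model.parameters(), target_model.parameters()))
```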
@@ -217,10 +230,18 @@ class Trainer:
             worker.child.send(("close", None))
 
 
-# ## Run it
-if __name__ == "__main__":
+def main():
+    # Create the experiment
     experiment.create(name='dqn')
+    # Initialize the trainer
     m = Trainer()
+    # Run and monitor the experiment
     with experiment.start():
         m.run_training_loop()
+    # Stop the workers
     m.destroy()
+
+
+# ## Run it
+if __name__ == "__main__":
+    main()
@@ -1,11 +1,11 @@
 """
 # PPO Experiment with Atari Breakout
 
-This experiment runs PPO Atari Breakout game on OpenAI Gym.
-It runs the [game environments on multiple processes](game.html) to sample efficiently.
+This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game on OpenAI Gym.
+It runs the [game environments on multiple processes](../game.html) to sample efficiently.
 """
 
-from typing import Dict, List
+from typing import Dict
 
 import numpy as np
 import torch
@@ -15,10 +15,11 @@ from torch.distributions import Categorical
 
 from labml import monit, tracker, logger, experiment
 from labml_helpers.module import Module
+from labml_nn.rl.game import Worker
 from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
 from labml_nn.rl.ppo.gae import GAE
-from labml_nn.rl.game import Worker
 
+# Select device
 if torch.cuda.is_available():
     device = torch.device("cuda:0")
 else:
@@ -82,6 +83,7 @@ class Trainer:
     """
     ## Trainer
     """
 
     def __init__(self):
         # #### Configurations
 
@@ -165,7 +167,6 @@ class Trainer:
                     # collect episode info, which is available if an episode finished;
                     # this includes total reward and length of the episode -
                     # look at `Game` to see how it works.
-                    # We also add a game frame to it for monitoring.
                     if info:
                         tracker.add('reward', info['reward'])
                         tracker.add('length', info['length'])
@@ -225,12 +226,16 @@ class Trainer:
                loss = self._calc_loss(clip_range=clip_range,
                                       samples=mini_batch)
 
-                # compute gradients
+                # Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = learning_rate
+                # Zero out the previously calculated gradients
                self.optimizer.zero_grad()
+                # Calculate gradients
                loss.backward()
+                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
+                # Update parameters based on gradients
                self.optimizer.step()
 
     @staticmethod
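This PPO hunk sets the (annealed) learning rate by writing to every optimizer param group, then zeroes gradients, backpropagates, clips the global gradient norm to 0.5, and steps. A standalone sketch of that same update pattern with a toy model and a dummy loss (not the repo's PPO loss):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)


def optimizer_step(loss: torch.Tensor, learning_rate: float):
    # Set the (possibly annealed) learning rate on every param group
    for pg in optimizer.param_groups:
        pg['lr'] = learning_rate
    # Zero old gradients, backpropagate, clip the global grad norm, then update
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    optimizer.step()


# Toy usage with a dummy loss
x = torch.randn(16, 8)
loss = model(x).pow(2).mean()
optimizer_step(loss, learning_rate=1e-4)
```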
@@ -311,8 +316,9 @@ class Trainer:
             # train the model
             self.train(samples, learning_rate, clip_range)
 
-            # write summary info to the writer, and log to the screen
+            # Save tracked indicators.
             tracker.save()
+            # Add a new line to the screen periodically
             if (update + 1) % 1_000 == 0:
                 logger.log()
 
@@ -325,10 +331,18 @@ class Trainer:
             worker.child.send(("close", None))
 
 
-# ## Run it
-if __name__ == "__main__":
+def main():
+    # Create the experiment
     experiment.create(name='ppo')
+    # Initialize the trainer
     m = Trainer()
+    # Run and monitor the experiment
     with experiment.start():
         m.run_training_loop()
+    # Stop the workers
     m.destroy()
+
+
+# ## Run it
+if __name__ == "__main__":
+    main()