This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game using OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE
Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class Model(Module):
    def __init__(self):
        super().__init__()
The first convolution layer takes an 84x84 frame and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
The second convolution layer takes a 20x20 frame and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
The third convolution layer takes a 9x9 frame and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)
A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)
        self.activation = nn.ReLU()
    def __call__(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
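As a quick sanity check of the frame sizes quoted above, here is a minimal sketch (not part of the training code; the names model_check and dummy_obs are illustrative) that passes a dummy batch through the model:

model_check = Model()
# 84x84 -> (84 - 8) // 4 + 1 = 20, then (20 - 4) // 2 + 1 = 9, then (9 - 3) // 1 + 1 = 7
dummy_obs = torch.zeros((2, 4, 84, 84))
pi_check, value_check = model_check(dummy_obs)
assert pi_check.logits.shape == (2, 4)  # one logit per Breakout action
assert value_check.shape == (2,)  # one value estimate per observation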
Scale observations from [0, 255] to [0, 1]
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
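For example, obs_to_torch(self.obs) in the trainer below converts the uint8 observation stack of shape (n_workers, 4, 84, 84) into a float32 tensor on device with values in [0, 1].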
class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
number of updates
        self.updates = updates
number of epochs to train the model with sampled data
        self.epochs = epochs
number of worker processes
        self.n_workers = n_workers
number of steps to run on each process for a single update
        self.worker_steps = worker_steps
number of mini batches
        self.batches = batches
total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
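For example, with the default configuration in main below (8 workers, 128 steps per worker, 4 mini batches), batch_size = 8 * 128 = 1024 samples per update and mini_batch_size = 1024 // 4 = 256.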
Value loss coefficient
        self.value_loss_coef = value_loss_coef
Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef
Clipping range
        self.clip_range = clip_range
Learning rate
        self.learning_rate = learning_rate
create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
model
        self.model = Model().to(device)
optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)
GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
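For reference, GAE computes $\hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^l \delta_{t+l}$ with $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, truncating the sum at the end of the sampled segment and masking terminal states; see the linked GAE implementation for the exact recursion.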
PPO Loss
        self.ppo_loss = ClippedPPOLoss()
Value Loss
        self.value_loss = ClippedValueFunctionLoss()
    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
sample worker_steps from each worker
            for t in range(self.worker_steps):
self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action
                obs[:, t] = self.obs
sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()
run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):
get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()
collect episode info, which is available if an episode finished; this includes total reward and length of the episode - look at Game to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
Get the value of the state after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()
calculate advantages
        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }
samples are currently laid out as [workers, time_step]; we flatten them for training
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
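With the default configuration below (8 workers x 128 steps), samples_flat['obs'] has shape (1024, 4, 84, 84) and each of actions, values, log_pis, and advantages has shape (1024,).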
    def train(self, samples: Dict[str, torch.Tensor]):
It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve this.
        for _ in range(self.epochs()):
shuffle for each epoch
            indexes = torch.randperm(self.batch_size)
for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]
train
                loss = self._calc_loss(mini_batch)
Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
Zero out the previously calculated gradients
                self.optimizer.zero_grad()
Calculate gradients
                loss.backward()
Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update parameters based on gradients
                self.optimizer.step()
    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
$R_t$ returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']
$\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$, where $\hat{A}_t$ is the advantage sampled from $\pi_{\theta_{OLD}}$. Refer to the sampling function in the Trainer class above for the calculation of $\hat{A}_t$.
        sampled_normalized_advantage = self._normalize(samples['advantages'])
Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state
        pi, value = self.model(samples['obs'])
$\log \pi_\theta (a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])
Calculate policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
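For reference, the clipped surrogate objective from the PPO paper is $\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t\Bigl[\min\bigl(r_t(\theta)\,\bar{A}_t,\ \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\,\bar{A}_t\bigr)\Bigr]$, where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)} = \exp\bigl(\log\pi_\theta(a_t|s_t) - \log\pi_{\theta_{OLD}}(a_t|s_t)\bigr)$ and $\epsilon$ is the clip range; ClippedPPOLoss returns the negated objective so that the combined loss below can be minimized (which is why the tracker logs -policy_loss as policy_reward).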
Calculate Entropy Bonus
$\mathcal{L}^{EB}(\theta) = \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()
Calculate value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
$\mathcal{L}^{CLIP+VF+EB} (\theta) = \mathcal{L}^{CLIP} (\theta) + c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)
approximate KL divergence between $\pi_{\theta_{OLD}}$ and $\pi_\theta$, for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss
    def run_training_loop(self):
keep information from the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)
        for update in monit.loop(self.updates):
sample with current policy
            samples = self.sample()
train the model
            self.train(samples)
Save tracked indicators.
            tracker.save()
Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
Create the experiment
    experiment.create(name='ppo')
Configurations
    configs = {
Number of updates
        'updates': 10000,
⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.
        'epochs': IntDynamicHyperParam(8),
Number of worker processes
        'n_workers': 8,
Number of steps to run on each process for a single update
        'worker_steps': 128,
Number of mini batches
        'batches': 4,
⚙️ Value loss coefficient.
        'value_loss_coef': FloatDynamicHyperParam(0.5),
⚙️ Entropy bonus coefficient.
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
⚙️ Clip range.
        'clip_range': FloatDynamicHyperParam(0.1),
⚙️ Learning rate.
        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)
Initialize the trainer
    m = Trainer(
        updates=configs['updates'],
        epochs=configs['epochs'],
        n_workers=configs['n_workers'],
        worker_steps=configs['worker_steps'],
        batches=configs['batches'],
        value_loss_coef=configs['value_loss_coef'],
        entropy_bonus_coef=configs['entropy_bonus_coef'],
        clip_range=configs['clip_range'],
        learning_rate=configs['learning_rate'],
    )
Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
Stop the workers
    m.destroy()
if __name__ == "__main__":
    main()
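Assuming this module is saved as labml_nn/rl/ppo/experiment.py (consistent with the imports above, though the path is not stated here), the experiment can be launched with something like python -m labml_nn.rl.ppo.experiment.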