This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game in OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Model(Module):
    def __init__(self):
        super().__init__()

The first convolution layer takes an 84x84 frame and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

The second convolution layer takes a 20x20 frame and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

The third convolution layer takes a 9x9 frame and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)

A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
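
As a quick check on the frame sizes quoted above, the convolution output formula $\lfloor (W - K)/S \rfloor + 1$ gives $84 \to 20 \to 9 \to 7$, so the flattened feature size is $7 \times 7 \times 64 = 3136$. Here is a minimal sketch (not part of the experiment) that verifies the output shapes with a dummy batch of two observations:

    model = Model()
    dummy_obs = torch.zeros(2, 4, 84, 84)
    pi, value = model(dummy_obs)
    assert pi.logits.shape == (2, 4)  # four Breakout actions
    assert value.shape == (2,)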

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int, n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):

number of updates
        self.updates = updates

number of epochs to train the model with sampled data
        self.epochs = epochs

number of worker processes
        self.n_workers = n_workers

number of steps to run on each process for a single update
        self.worker_steps = worker_steps

number of mini batches
        self.batches = batches

total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps

size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
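
With the default configuration in main() below, that is $8 \times 128 = 1024$ samples per update, split into $4$ mini batches of $256$ samples each.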

Value loss coefficient
        self.value_loss_coef = value_loss_coef

Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef

Clipping range
        self.clip_range = clip_range

Learning rate
        self.learning_rate = learning_rate

create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

model
        self.model = Model().to(device)

optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)

PPO Loss
        self.ppo_loss = ClippedPPOLoss()

Value Loss
        self.value_loss = ClippedValueFunctionLoss()
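
For reference, GAE computes the advantage estimate as an exponentially weighted sum of TD errors (the GAE class also uses the done flags passed to it to handle episode boundaries):

$\hat{A}_t = \sum_{l \geq 0} (\gamma \lambda)^l \delta_{t+l}$, where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$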

    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():

sample worker_steps from each worker
            for t in range(self.worker_steps):

self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action
                obs[:, t] = self.obs

sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):

get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

collect episode info, which is available if an episode finished; this includes the total reward and the length of the episode - look at Game to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Get the value of the state after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

calculate advantages
        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

samples are currently in a [workers, time_step] table; we should flatten them for training
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
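
For example, with the default 8 workers and 128 worker steps, the obs entry goes from shape (8, 128, 4, 84, 84) to (1024, 4, 84, 84), and the per-step scalars from (8, 128) to (1024,).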

    def train(self, samples: Dict[str, torch.Tensor]):

It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Maybe reducing the clipping range would solve it.

        for _ in range(self.epochs):

shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):

get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

train
                loss = self._calc_loss(mini_batch)

Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients
                self.optimizer.zero_grad()

Calculate gradients
                loss.backward()

Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients
                self.optimizer.step()

    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

$R_t$ returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']

$\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$, where $\hat{A_t}$ are the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sample function in the Trainer class above for the calculation of $\hat{A}_t$.
        sampled_normalized_advantage = self._normalize(samples['advantages'])

Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state
        pi, value = self.model(samples['obs'])

$\log \pi_\theta (a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])

Calculate policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
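
For reference, the clipped surrogate objective from the PPO paper, with probability ratio $r_t(\theta) = \exp\bigl(\log \pi_\theta(a_t|s_t) - \log \pi_{\theta_{OLD}}(a_t|s_t)\bigr)$, is

$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t \Bigl[ \min\bigl( r_t(\theta) \bar{A_t},\ \text{clip}\bigl(r_t(\theta), 1 - \epsilon, 1 + \epsilon\bigr) \bar{A_t} \bigr) \Bigr]$

policy_loss is the negative of this objective (so it can be minimized), which is why -policy_loss is tracked as policy_reward below.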

Calculate Entropy Bonus: $\mathcal{L}^{EB}(\theta) = \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
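
ClippedValueFunctionLoss applies the same clipping idea to the value function: the new value estimate is kept within $\pm \epsilon$ of the old one, and the larger of the two squared errors against the sampled return $R_t$ is used. This is the standard clipped value loss, using the same clip_range as the policy:

$\mathcal{L}^{VF}(\theta) = \frac{1}{2} \mathbb{E}_t \Bigl[ \max\Bigl( \bigl(V_\theta(s_t) - R_t\bigr)^2,\ \bigl(V_{\theta_{OLD}}(s_t) + \text{clip}\bigl(V_\theta(s_t) - V_{\theta_{OLD}}(s_t), -\epsilon, \epsilon\bigr) - R_t\bigr)^2 \Bigr) \Bigr]$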

$\mathcal{L}^{CLIP+VF+EB} (\theta) = \mathcal{L}^{CLIP} (\theta) + c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
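
This is the simple estimator $\tfrac{1}{2} \mathbb{E}\bigl[(\log \pi_{\theta_{OLD}}(a_t|s_t) - \log \pi_\theta(a_t|s_t))^2\bigr]$, which approximates the KL divergence between the old and current policies; it is only tracked, not optimized.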

Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

    def run_training_loop(self):

last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):

sample with current policy
            samples = self.sample()

train the model
            self.train(samples)

Save tracked indicators.
            tracker.save()

Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment
    experiment.create(name='ppo')

Configurations
    configs = {

number of updates
        'updates': 10000,

number of epochs to train the model with sampled data
        'epochs': 4,

number of worker processes
        'n_workers': 8,

number of steps to run on each process for a single update
        'worker_steps': 128,

number of mini batches
        'batches': 4,

Value loss coefficient
        'value_loss_coef': FloatDynamicHyperParam(0.5),

Entropy bonus coefficient
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),

Clip range
        'clip_range': FloatDynamicHyperParam(0.1),

Learning rate
        'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
    }
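
The FloatDynamicHyperParam values are read inside the trainer by calling them, e.g. self.clip_range() and self.learning_rate(), so their current values are picked up on every mini batch; the learning rate is additionally constrained to the range (0, 1e-3) declared above. This is what makes it possible to adjust them while the experiment is running.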

    experiment.configs(configs)

Initialize the trainer
    m = Trainer(
        updates=configs['updates'],
        epochs=configs['epochs'],
        n_workers=configs['n_workers'],
        worker_steps=configs['worker_steps'],
        batches=configs['batches'],
        value_loss_coef=configs['value_loss_coef'],
        entropy_bonus_coef=configs['entropy_bonus_coef'],
        clip_range=configs['clip_range'],
        learning_rate=configs['learning_rate'],
    )

Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()

Stop the workers
    m.destroy()


if __name__ == "__main__":
    main()