This experiment trains a Proximal Policy Optimization (PPO) agent to play Atari Breakout on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
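The multi-process sampling is handled by the Worker class from labml_nn.rl.game: each game runs in its own process and the trainer talks to it over a pipe. The following is a rough illustrative sketch only, not the actual implementation; it assumes the Game wrapper from the same module and mirrors the send/recv protocol used by the trainer below.

import multiprocessing

from labml_nn.rl.game import Game  # assumed wrapper; see labml_nn.rl.game for the real code


def worker_process(remote, seed: int):
    # Sketch of the child process: it owns one Game instance and serves
    # commands received over the pipe until it is asked to close.
    game = Game(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            # the step result is (observation, reward, done, info)
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break


class Worker:
    def __init__(self, seed: int):
        # `child` is the end the trainer keeps; the other end goes to the process
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()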
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE
Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class Model(nn.Module):
    def __init__(self):
        super().__init__()
The first convolution layer takes an 84x84 frame and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
The second convolution layer takes a 20x20 frame and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
The third convolution layer takes a 9x9 frame and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
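With no padding, each convolution outputs $\lfloor (n - k)/s \rfloor + 1$ pixels per side: $(84 - 8)/4 + 1 = 20$, $(20 - 4)/2 + 1 = 9$ and $(9 - 3)/1 + 1 = 7$, which is where the $7 \times 7 \times 64 = 3136$ input features of this layer come from.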
A fully connected layer to get logits for the policy $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)
A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()
    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
Scale observations from [0, 255] to [0, 1]
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
number of updates
        self.updates = updates
number of epochs to train the model with sampled data
        self.epochs = epochs
number of worker processes
        self.n_workers = n_workers
number of steps to run on each process for a single update
        self.worker_steps = worker_steps
number of mini batches
        self.batches = batches
total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
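For example, with the configuration used in main() below (n_workers = 8, worker_steps = 128, batches = 4), batch_size is $8 \times 128 = 1024$ and mini_batch_size is $1024 / 4 = 256$.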
Value loss coefficient
        self.value_loss_coef = value_loss_coef
Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef
Clipping range
        self.clip_range = clip_range
Learning rate
        self.learning_rate = learning_rate
create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
model
        self.model = Model().to(device)
optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)
GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
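GAE estimates the advantage as a discounted sum of TD errors, $\hat{A}_t = \sum_{k \geq 0} (\gamma\lambda)^k \delta_{t+k}$ with $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$; the sum is cut off at episode boundaries using the done flags, and here $\gamma = 0.99$ and $\lambda = 0.95$.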
PPO Loss
        self.ppo_loss = ClippedPPOLoss()
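ClippedPPOLoss implements the clipped surrogate objective from the PPO paper, $\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t\big[\min\big(r_t(\theta)\,\bar{A}_t,\ \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\,\bar{A}_t\big)\big]$ with the probability ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)}$. The loss it returns is the negative of this objective, so minimizing it maximizes the clipped surrogate (hence the policy_reward = -policy_loss tracking in _calc_loss below).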
Value Loss
        self.value_loss = ClippedValueFunctionLoss()
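The clipped value function loss typically has the form $\mathcal{L}^{VF}(\theta) = \frac{1}{2}\,\mathbb{E}_t\big[\max\big((V_\theta(s_t) - R_t)^2,\ (V^{clipped}_\theta(s_t) - R_t)^2\big)\big]$ with $V^{clipped}_\theta(s_t) = V_{\theta_{OLD}}(s_t) + \mathrm{clip}\big(V_\theta(s_t) - V_{\theta_{OLD}}(s_t), -\epsilon, \epsilon\big)$, which keeps the value estimates from moving too far from those computed while sampling.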
    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
sample worker_steps from each worker
            for t in range(self.worker_steps):
self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action
                obs[:, t] = self.obs
sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()
run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):
get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()
collect episode info, which is available if an episode finished; this includes the total reward and length of the episode - look at Game to see how it works
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
Get the value of the state after the final step; GAE uses this to bootstrap the advantage of the last step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()
calculate advantages
        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }
samples are currently in a [workers, time_step] table; we should flatten them for training
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
    def train(self, samples: Dict[str, torch.Tensor]):
It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve this.
        for _ in range(self.epochs()):
shuffle for each epoch
            indexes = torch.randperm(self.batch_size)
for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]
train
                loss = self._calc_loss(mini_batch)
Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
Zero out the previously calculated gradients
                self.optimizer.zero_grad()
Calculate gradients
                loss.backward()
Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
Update parameters based on gradients
                self.optimizer.step()
    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)
    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
$R_t$, the returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']
Normalize advantages, $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$, where $\hat{A_t}$ are the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sample() method above for the calculation of $\hat{A_t}$.
        sampled_normalized_advantage = self._normalize(samples['advantages'])
Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state
        pi, value = self.model(samples['obs'])
Calculate $\log \pi_\theta(a_t|s_t)$, where $a_t$ are actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])
Calculate policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
Calculate the entropy bonus $\mathcal{H}[\pi_\theta]$
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()
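For a categorical policy the entropy is $\mathcal{H}[\pi_\theta](s_t) = -\sum_a \pi_\theta(a|s_t) \log \pi_\theta(a|s_t)$; adding it as a bonus discourages the policy from collapsing to a deterministic one too early.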
Calculate value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
Combine the policy loss, value function loss, and entropy bonus into a single loss to minimize
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)
Approximate KL divergence, $\frac{1}{2}\,\mathbb{E}\big[(\log \pi_{\theta_{OLD}} - \log \pi_\theta)^2\big]$, for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss
    def run_training_loop(self):
keep information from the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
sample with current policy
            samples = self.sample()
train the model
            self.train(samples)
Save tracked indicators.
            tracker.save()
Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
Create the experiment
    experiment.create(name='ppo')
Configurations
    configs = {
Number of updates
        'updates': 10000,
⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.
        'epochs': IntDynamicHyperParam(8),
Number of worker processes
        'n_workers': 8,
Number of steps to run on each process for a single update
        'worker_steps': 128,
Number of mini batches
        'batches': 4,
⚙️ Value loss coefficient. You can change this while the experiment is running.
        'value_loss_coef': FloatDynamicHyperParam(0.5),
⚙️ Entropy bonus coefficient. You can change this while the experiment is running.
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
⚙️ Clip range. You can change this while the experiment is running.
        'clip_range': FloatDynamicHyperParam(0.1),
⚙️ Learning rate. You can change this while the experiment is running.
        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }
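These ⚙️ dynamic hyper-parameters are read by calling them, e.g. self.epochs(), self.learning_rate() and self.clip_range() inside Trainer, so values changed while the experiment is running take effect on subsequent updates.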

    experiment.configs(configs)
Initialize the trainer
    m = Trainer(**configs)
Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
Stop the workers
    m.destroy()
if __name__ == "__main__":
    main()