This experiment trains a Proximal Policy Optimization (PPO) agent to play the Atari Breakout game from OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
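Each worker runs one game instance in a separate process and talks to the trainer over a pipe. The actual `Worker` class lives in `labml_nn.rl.game`; as a rough sketch of the pattern (with a hypothetical `make_game` factory standing in for the wrapped Atari environment), it might look like this:

```python
import multiprocessing as mp


def _worker_process(remote, seed):
    # `make_game` is hypothetical: the real code wraps an Atari environment
    # with frame stacking and episode bookkeeping
    game = make_game(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break


class Worker:
    def __init__(self, seed):
        # `child` stays with the trainer; the other end goes to the process
        self.child, parent = mp.Pipe()
        self.process = mp.Process(target=_worker_process, args=(parent, seed))
        self.process.start()
```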
```python
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE
```

Select the device:

```python
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
```

The model:

```python
class Model(Module):
    def __init__(self):
        super().__init__()

        # The first convolution layer takes an 84x84 frame
        # and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

        # The second convolution layer takes a 20x20 frame
        # and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

        # The third convolution layer takes a 9x9 frame
        # and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

        # A fully connected layer takes the flattened frame from the
        # third convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        # A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)

        # A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()
```
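The frame sizes in the comments follow from the usual convolution arithmetic, `out = (in - kernel) // stride + 1` with no padding. A quick sanity check:

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    # output size of a convolution with no padding
    return (size - kernel) // stride + 1


assert conv_out(84, 8, 4) == 20  # conv1: 84x84 -> 20x20
assert conv_out(20, 4, 2) == 9   # conv2: 20x20 -> 9x9
assert conv_out(9, 3, 1) == 7    # conv3: 9x9 -> 7x7
```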
```python
    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
```
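The model returns a `Categorical` policy distribution and a value estimate per observation. A quick, illustrative smoke test of that interface (batch size and zero input chosen arbitrarily):

```python
model = Model().to(device)
dummy_obs = torch.zeros((2, 4, 84, 84), device=device)  # batch of 2 stacked frames
pi, value = model(dummy_obs)
assert pi.sample().shape == (2,) and value.shape == (2,)
```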
Scale observations from `[0, 255]` to `[0, 1]`:
```python
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
```
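For example, a batch of `uint8` frame stacks becomes a `float32` tensor in `[0, 1]`:

```python
frames = np.random.randint(0, 256, size=(2, 4, 84, 84), dtype=np.uint8)
t = obs_to_torch(frames)
assert t.dtype == torch.float32
assert float(t.min()) >= 0.0 and float(t.max()) <= 1.0
```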
The trainer:

```python
class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # number of updates
        self.updates = updates
        # number of epochs to train the model with sampled data
        self.epochs = epochs
        # number of worker processes
        self.n_workers = n_workers
        # number of steps to run on each process for a single update
        self.worker_steps = worker_steps
        # number of mini batches
        self.batches = batches
        # total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
        # size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
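        # With the defaults in `main()` below (8 workers, 128 worker steps,
        # 4 mini batches): batch_size = 8 * 128 = 1024 samples per update,
        # and mini_batch_size = 1024 // 4 = 256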
        # value loss coefficient
        self.value_loss_coef = value_loss_coef
        # entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef
        # clipping range
        self.clip_range = clip_range
        # learning rate
        self.learning_rate = learning_rate
        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # model
        self.model = Model().to(device)

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

        # GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)

        # PPO loss
        self.ppo_loss = ClippedPPOLoss()

        # value loss
        self.value_loss = ClippedValueFunctionLoss()
```
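The `GAE` helper computes Generalized Advantage Estimation over the sampled trajectories. A minimal sketch of the recursion it implements, assuming `values` carries one extra bootstrap entry per worker (the actual implementation is in `labml_nn.rl.ppo.gae`):

```python
import numpy as np


def gae_sketch(done: np.ndarray, rewards: np.ndarray, values: np.ndarray,
               gamma: float = 0.99, lam: float = 0.95) -> np.ndarray:
    # done, rewards: [workers, steps]; values: [workers, steps + 1]
    n_workers, steps = rewards.shape
    advantages = np.zeros((n_workers, steps), dtype=np.float32)
    last_advantage = np.zeros(n_workers, dtype=np.float32)
    for t in reversed(range(steps)):
        mask = 1.0 - done[:, t]  # stop bootstrapping across episode boundaries
        # TD error: $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
        delta = rewards[:, t] + gamma * values[:, t + 1] * mask - values[:, t]
        # $\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$
        last_advantage = delta + gamma * lam * last_advantage * mask
        advantages[:, t] = last_advantage
    return advantages
```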
Sample data with the current policy:

```python
    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
            # sample `worker_steps` from each worker
            for t in range(self.worker_steps):
                # `self.obs` keeps track of the last observation from each worker,
                # which is the input for the model to sample the next action
                obs[:, t] = self.obs
                # sample actions from $\pi_{\theta_{OLD}}$ for each worker;
                # this returns arrays of size `n_workers`
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

                # run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):
                    # get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

                    # collect episode info, which is available when an episode finishes;
                    # this includes the total reward and the length of the episode.
                    # Look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

            # get the value of the state after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

        # calculate advantages
        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

        # samples are currently in a `[workers, time_step]` table;
        # flatten them for training
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
```
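With 8 workers and 128 worker steps, the `obs` entry, for instance, goes from shape `(8, 128, 4, 84, 84)` to `(1024, 4, 84, 84)`:

```python
import numpy as np

v = np.zeros((8, 128, 4, 84, 84), dtype=np.uint8)
flat = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
assert flat.shape == (1024, 4, 84, 84)
```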
Train the model based on the samples. It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not increase monotonically over time. Reducing the clipping range might solve this.

```python
    def train(self, samples: Dict[str, torch.Tensor]):
        for _ in range(self.epochs()):
            # shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

            # for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
                # get the mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

                # train
                loss = self._calc_loss(mini_batch)

                # set the learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
                # zero out the previously calculated gradients
                self.optimizer.zero_grad()
                # calculate gradients
                loss.backward()
                # clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                # update parameters based on gradients
                self.optimizer.step()
```
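Note that `learning_rate`, `clip_range`, and the other coefficients are dynamic hyperparameters: calling them (e.g. `self.learning_rate()`) returns the current value, which can be adjusted from the labml monitoring app while the experiment runs. As used in this file:

```python
from labml.configs import FloatDynamicHyperParam

# initial value 2.5e-4, adjustable within the range (0, 1e-3)
lr = FloatDynamicHyperParam(2.5e-4, (0, 1e-3))
print(lr())  # -> 0.00025, until changed from the monitoring app
```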
Normalize advantages and calculate the total loss:

```python
    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
        # $R_t$, returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']

        # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$,
        # where $\hat{A_t}$ are the advantages sampled from $\pi_{\theta_{OLD}}$;
        # refer to the `sample` function above for the calculation of $\hat{A}_t$
        sampled_normalized_advantage = self._normalize(samples['advantages'])

        # sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$
        # and $V^{\pi_\theta}(s_t)$; we are treating observations as state
        pi, value = self.model(samples['obs'])

        # $\log \pi_\theta(a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])

        # calculate the policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())

        # calculate the entropy bonus
        # $\mathcal{L}^{EB}(\theta) = \mathbb{E}[H(\pi_\theta(s_t))]$
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

        # calculate the value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

        # $\mathcal{L}^{CLIP+VF+EB}(\theta) = \mathcal{L}^{CLIP}(\theta)
        #  + c_1 \mathcal{L}^{VF}(\theta) - c_2 \mathcal{L}^{EB}(\theta)$
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

        # approximate KL divergence, for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

        # add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss
```
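`ClippedPPOLoss` and `ClippedValueFunctionLoss` are imported from `labml_nn.rl.ppo`. As a reference for what they compute, here is a minimal sketch of the two clipped objectives; the library versions also track statistics such as `clip_fraction`:

```python
import torch


def clipped_ppo_loss(log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                     advantage: torch.Tensor, clip: float) -> torch.Tensor:
    # importance ratio $r_t = \pi_\theta(a_t|s_t) / \pi_{\theta_{OLD}}(a_t|s_t)$,
    # computed in log space for numerical stability
    ratio = torch.exp(log_pi - sampled_log_pi)
    clipped = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
    # negate because we minimize the loss but maximize the surrogate objective
    return -torch.min(ratio * advantage, clipped * advantage).mean()


def clipped_value_loss(value: torch.Tensor, sampled_value: torch.Tensor,
                       sampled_return: torch.Tensor, clip: float) -> torch.Tensor:
    # clip the value estimate so it does not move too far from the sampled value
    clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
    vf_loss = torch.max((value - sampled_return) ** 2,
                        (clipped_value - sampled_return) ** 2)
    return 0.5 * vf_loss.mean()
```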
Run the training loop:

```python
    def run_training_loop(self):
        # track the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            # sample with the current policy
            samples = self.sample()
            # train the model
            self.train(samples)
            # save tracked indicators
            tracker.save()
            # add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()
```

Stop the workers:

```python
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
```
Create and run the experiment:

```python
def main():
    # create the experiment
    experiment.create(name='ppo')

    # configurations
    configs = {
        # number of updates
        'updates': 10000,
        # ⚙️ number of epochs to train the model with sampled data;
        # you can change this while the experiment is running
        'epochs': IntDynamicHyperParam(8),
        # number of worker processes
        'n_workers': 8,
        # number of steps to run on each process for a single update
        'worker_steps': 128,
        # number of mini batches
        'batches': 4,
        # ⚙️ value loss coefficient
        'value_loss_coef': FloatDynamicHyperParam(0.5),
        # ⚙️ entropy bonus coefficient
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
        # ⚙️ clip range
        'clip_range': FloatDynamicHyperParam(0.1),
        # ⚙️ learning rate
        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

    # initialize the trainer
    m = Trainer(**configs)

    # run and monitor the experiment
    with experiment.start():
        m.run_training_loop()

    # stop the workers
    m.destroy()


if __name__ == "__main__":
    main()
```