PPO Experiment with Atari Breakout

This experiment trains a Proximal Policy Optimization (PPO) agent to play the Atari Breakout game using OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.

from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Model

class Model(Module):
    def __init__(self):
        super().__init__()

The first convolution layer takes an 84x84 frame and produces a 20x20 frame

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

The second convolution layer takes a 20x20 frame and produces a 9x9 frame

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

The third convolution layer takes a 9x9 frame and produces a 7x7 frame

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
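As a quick check of these frame sizes (a minimal sketch, not part of the original code; conv_out is a hypothetical helper), the usual output-size formula for an unpadded convolution, $\lfloor (n - k) / s \rfloor + 1$, reproduces 84 → 20 → 9 → 7:

def conv_out(n: int, kernel_size: int, stride: int) -> int:
    # output size of an unpadded convolution: floor((n - kernel_size) / stride) + 1
    return (n - kernel_size) // stride + 1

assert conv_out(84, 8, 4) == 20
assert conv_out(20, 4, 2) == 9
assert conv_out(9, 3, 1) == 7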

A fully connected layer takes the flattened frame from the third convolution layer and outputs 512 features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

A fully connected layer to get logits for $\pi$

        self.pi_logits = nn.Linear(in_features=512, out_features=4)

A fully connected layer to get the value function

        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()
    def __call__(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
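A minimal sketch of a forward pass (hypothetical, not part of the original code): the model returns a Categorical distribution over the 4 Breakout actions and one value estimate per sample.

model = Model().to(device)
dummy_obs = torch.zeros((2, 4, 84, 84), device=device)
pi, value = model(dummy_obs)
assert pi.logits.shape == (2, 4)
assert value.shape == (2,)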

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
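For example (with a hypothetical array, just to illustrate the scaling), a stack of uint8 frames becomes a float32 tensor in [0, 1] on the selected device:

frames = np.full((8, 4, 84, 84), 255, dtype=np.uint8)
scaled = obs_to_torch(frames)
assert scaled.dtype == torch.float32
assert scaled.max().item() == 1.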

Trainer

class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int, n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):

Configurations

number of updates

        self.updates = updates

number of epochs to train the model with sampled data

        self.epochs = epochs

number of worker processes

        self.n_workers = n_workers

number of steps to run on each process for a single update

        self.worker_steps = worker_steps

number of mini batches

        self.batches = batches

total number of samples for a single update

        self.batch_size = self.n_workers * self.worker_steps

size of a mini batch

        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
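For instance, with the default configuration used in main below (n_workers=8, worker_steps=128, batches=4), this gives batch_size = 8 × 128 = 1024 samples per update and mini_batch_size = 1024 / 4 = 256.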

Value loss coefficient

        self.value_loss_coef = value_loss_coef

Entropy bonus coefficient

        self.entropy_bonus_coef = entropy_bonus_coef

Clipping range

        self.clip_range = clip_range

Learning rate

        self.learning_rate = learning_rate

Initialize

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

model

        self.model = Model().to(device)

optimizer

        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$

        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
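For reference, this is the standard generalized advantage estimation, computed backwards in time from temporal-difference errors (see labml_nn.rl.ppo.gae for the actual implementation, which also masks out terminal steps):

$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$

$\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$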

PPO Loss

        self.ppo_loss = ClippedPPOLoss()

Value Loss

        self.value_loss = ClippedValueFunctionLoss()

Sample data with current policy

    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():

sample worker_steps from each worker

            for t in range(self.worker_steps):

self.obs keeps track of the last observation from each worker, which is the input for the model to sample the next action

                obs[:, t] = self.obs

sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers

                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):

get results after executing the actions

                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

collect episode info, which is available if an episode finished; this includes total reward and length of the episode - look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Get the value after the final step

            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

calculate advantages

        advantages = self.gae(done, rewards, values)

        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

Samples are currently in [workers, time_step] tables; we should flatten them for training

        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
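With the default configuration, 'obs' goes from shape (8, 128, 4, 84, 84) to (1024, 4, 84, 84), while per-step scalars such as 'actions', 'log_pis' and 'advantages' go from (8, 128) to (1024,).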

Train the model based on samples

    def train(self, samples: Dict[str, torch.Tensor]):

It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve this.

        for _ in range(self.epochs):

shuffle for each epoch

            indexes = torch.randperm(self.batch_size)

for each mini batch

            for start in range(0, self.batch_size, self.mini_batch_size):

get mini batch

                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

train

                loss = self._calc_loss(mini_batch)

Set learning rate

                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

                self.optimizer.zero_grad()

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

                self.optimizer.step()

Normalize advantage function

    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

Calculate total loss

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

$R_t$ returns sampled from $\pi_{\theta_{OLD}}$, reconstructed by adding the sampled value estimates to the GAE advantages

        sampled_return = samples['values'] + samples['advantages']

$\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$, where $\hat{A}_t$ are the advantages sampled from $\pi_{\theta_{OLD}}$. Refer to the sample method of the Trainer class above for the calculation of $\hat{A}_t$.

        sampled_normalized_advantage = self._normalize(samples['advantages'])

Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$; we are treating observations as state

        pi, value = self.model(samples['obs'])

$\log \pi_\theta (a_t|s_t)$, where $a_t$ are actions sampled from $\pi_{\theta_{OLD}}$

        log_pi = pi.log_prob(samples['actions'])

Calculate policy loss

        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
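For reference, the clipped surrogate objective from the PPO paper is $\mathbb{E}_t \Bigl[ \min \bigl( r_t(\theta) \bar{A}_t, \; \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \bar{A}_t \bigr) \Bigr]$ with $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)}$ and $\epsilon$ the clipping range; ClippedPPOLoss returns its negation as $\mathcal{L}^{CLIP}(\theta)$, so minimizing the loss maximizes the objective (which is why -policy_loss is tracked as policy_reward below).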

Calculate Entropy Bonus

$\mathcal{L}^{EB}(\theta) = \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$

        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate value function loss

        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
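ClippedValueFunctionLoss similarly clips how far the new value estimate can move from the sampled value before taking the squared error against the return, roughly

$\mathcal{L}^{VF}(\theta) = \frac{1}{2} \mathbb{E}_t \Bigl[ \max \bigl( (V^{\pi_\theta}(s_t) - R_t)^2, \; (\mathrm{clip}(V^{\pi_\theta}(s_t), V_{\theta_{OLD}}(s_t) - \epsilon, V_{\theta_{OLD}}(s_t) + \epsilon) - R_t)^2 \bigr) \Bigr]$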

$\mathcal{L}^{CLIP+VF+EB} (\theta) = \mathcal{L}^{CLIP} (\theta) + c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$

        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

for monitoring

        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
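This is a simple estimator of the KL divergence between the old and new policies, $\hat{D}_{KL} \approx \frac{1}{2} \mathbb{E} \bigl[ (\log \pi_{\theta_{OLD}}(a_t|s_t) - \log \pi_\theta(a_t|s_t))^2 \bigr]$, which is reasonable when the two policies are close; it is only tracked, not optimized.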

Add to tracker

        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

Run training loop

    def run_training_loop(self):

Keep information about the last 100 episodes

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):

sample with current policy

            samples = self.sample()

train the model

            self.train(samples)

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Destroy

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))


def main():

Create the experiment

    experiment.create(name='ppo')

Configurations

    configs = {

number of updates

        'updates': 10000,

number of epochs to train the model with sampled data

        'epochs': 4,

number of worker processes

        'n_workers': 8,

number of steps to run on each process for a single update

        'worker_steps': 128,

number of mini batches

        'batches': 4,

Value loss coefficient

        'value_loss_coef': FloatDynamicHyperParam(0.5),

Entropy bonus coefficient

        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),

Clip range

        'clip_range': FloatDynamicHyperParam(0.1),

Learning rate

        'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
    }

    experiment.configs(configs)
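The FloatDynamicHyperParam values are callable: calling one (for example self.clip_range() in the trainer) returns its current value, and they can be adjusted from the labml app while the experiment is running; the (0, 1e-3) tuple for the learning rate is presumably the allowed adjustment range. A minimal standalone sketch (hypothetical values):

lr = FloatDynamicHyperParam(2.5e-4, (0, 1e-3))
print(lr())  # 0.00025, until changed from the app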

Initialize the trainer

    m = Trainer(
        updates=configs['updates'],
        epochs=configs['epochs'],
        n_workers=configs['n_workers'],
        worker_steps=configs['worker_steps'],
        batches=configs['batches'],
        value_loss_coef=configs['value_loss_coef'],
        entropy_bonus_coef=configs['entropy_bonus_coef'],
        clip_range=configs['clip_range'],
        learning_rate=configs['learning_rate'],
    )

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()

Run it

if __name__ == "__main__":
    main()