Generalized Advantage Estimation (GAE)

This is a PyTorch implementation of the paper High-Dimensional Continuous Control Using Generalized Advantage Estimation.

import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers
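For instance, the estimator for a typical PPO setup might be constructed like this (the worker count, horizon, and the common defaults $\gamma = 0.99$, $\lambda = 0.95$ are illustrative, not mandated by the paper):

# hypothetical configuration: 8 parallel workers, 128 steps per worker per update
gae = GAE(n_workers=8, worker_steps=128, gamma=0.99, lambda_=0.95)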

Calculate advantages

$\hat{A_t^{(1)}}$ has high bias and low variance, whilst $\hat{A_t^{(\infty)}}$ is unbiased but has high variance.
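Here $\hat{A_t^{(k)}}$ is the $k$-step advantage estimate built from TD errors $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, as in the GAE paper:

$$\hat{A_t^{(k)}} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l} = -V(s_t) + r_t + \gamma r_{t+1} + \dots + \gamma^{k-1} r_{t+k-1} + \gamma^k V(s_{t+k})$$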

We take a weighted average of $\hat{A_t^{(k)}}$ to balance bias and variance. This is called Generalized Advantage Estimation. Setting $w_k = \lambda^{k-1}$ (with the $(1-\lambda)$ normalization used in the paper) gives a clean recursive calculation for $\hat{A_t}$.
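Expanding the weighted sum, the coefficients of the TD errors form geometric series and the estimate collapses into a recursion (a standard derivation from the GAE paper):

$$\hat{A_t} = (1-\lambda)\left(\hat{A_t^{(1)}} + \lambda \hat{A_t^{(2)}} + \lambda^2 \hat{A_t^{(3)}} + \dots\right) = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l} = \delta_t + \gamma\lambda \hat{A_{t+1}}$$

At the extremes, $\lambda = 0$ recovers the one-step estimate $\delta_t$ and $\lambda = 1$ the unbiased discounted sum of TD errors, so $\lambda$ interpolates between bias and variance.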

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:

advantages table

        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0

$V(s_{t+1})$ for the final step; values is expected to have worker_steps + 1 columns, the extra column holding the value of the state after the last step

        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):

mask if episode completed after step $t$

            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask

$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$

            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

$\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$

            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

Note that we are filling the advantages table in reverse order. My initial code appended to a list, and I forgot to reverse it later. It took me around 4 to 5 hours to find that bug. The model's performance was still improving slightly during the initial runs, probably because consecutive samples are similar.

            advantages[:, t] = last_advantage

            last_value = values[:, t]

        return advantages
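To see the estimator in action, here is a minimal sanity-check sketch (not part of the original implementation; all shapes and hyperparameter values are illustrative). With no episode terminations, the recursion above should match a brute-force evaluation of the truncated sum $\sum_l (\gamma\lambda)^l \delta_{t+l}$:

import numpy as np

# illustrative check, not part of the original code
n_workers, worker_steps = 2, 5
gamma, lambda_ = 0.99, 0.95
rng = np.random.default_rng(0)
rewards = rng.normal(size=(n_workers, worker_steps)).astype(np.float32)
# one extra column for the value of the state after the last step
values = rng.normal(size=(n_workers, worker_steps + 1)).astype(np.float32)
done = np.zeros((n_workers, worker_steps), dtype=np.float32)  # no terminations

gae = GAE(n_workers, worker_steps, gamma, lambda_)
advantages = gae(done, rewards, values)

# brute force: advantages[:, t] should equal sum_l (gamma * lambda)^l * delta_{t+l}
deltas = rewards + gamma * values[:, 1:] - values[:, :-1]
expected = np.zeros_like(deltas)
for t in range(worker_steps):
    for l in range(worker_steps - t):
        expected[:, t] += (gamma * lambda_) ** l * deltas[:, t + l]

assert np.allclose(advantages, expected, atol=1e-5)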