"""
# Generalized Advantage Estimation (GAE)

This is an implementation of the paper [Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438).
"""

import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
        """
        ### Calculate advantages

        \begin{align}
        \hat{A_t^{(1)}} &= r_t + \gamma V(s_{t+1}) - V(s_t)
        \\
        \hat{A_t^{(2)}} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)
        \\
        ...
        \\
        \hat{A_t^{(\infty)}} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + ... - V(s_t)
        \end{align}

        $\hat{A_t^{(1)}}$ has high bias and low variance, whilst
        $\hat{A_t^{(\infty)}}$ is unbiased but has high variance.

        We take a weighted average of $\hat{A_t^{(k)}}$ to balance bias and variance.
        This is called Generalized Advantage Estimation.
        $$\hat{A_t} = \hat{A_t^{GAE}} = \sum_k w_k \hat{A_t^{(k)}}$$
        We set $w_k = \lambda^{k-1}$. Each $\hat{A_t^{(k)}}$ telescopes into a sum of
        the TD errors defined below, $\hat{A_t^{(k)}} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l}$,
        so the weighted sum collapses (up to a constant $\frac{1}{1 - \lambda}$ factor)
        into a clean calculation for $\hat{A_t}$:

        \begin{align}
        \delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)
        \\
        \hat{A_t} &= \delta_t + \gamma \lambda \delta_{t+1} + ... +
                             (\gamma \lambda)^{T - t - 1} \delta_{T - 1}
        \\
        &= \delta_t + \gamma \lambda \hat{A_{t+1}}
        \end{align}
        """

        # advantages table
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0

        # $V(s_{t+1})$
        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):
            # mask if episode completed after step $t$
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
            # $\delta_t$
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

            # $\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

            # note that we are collecting in reverse order.
            # *My initial code was appending to a list and
            #   I forgot to reverse it later.
            # It took me around 4 to 5 hours to find the bug.
            # The performance of the model was improving
            #  slightly during initial runs,
            #  probably because the samples are similar.*
            advantages[:, t] = last_advantage

            last_value = values[:, t]

        return advantages
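

# A minimal usage / self-check sketch. It assumes `done` and `rewards` are
# `[n_workers, worker_steps]` arrays and that `values` carries one extra column
# holding the value estimate of the state after the last step (this is what
# `values[:, -1]` bootstraps from above); these shapes are assumptions, not
# something enforced by `GAE` itself.
if __name__ == '__main__':
    n_workers, worker_steps = 4, 8
    gamma, lambda_ = 0.99, 0.95

    rng = np.random.default_rng(0)
    rewards = rng.normal(size=(n_workers, worker_steps)).astype(np.float32)
    values = rng.normal(size=(n_workers, worker_steps + 1)).astype(np.float32)
    # No episode terminates inside this rollout, so the recursive computation should
    # match the direct sum $\hat{A_t} = \sum_{l=0}^{T-t-1} (\gamma \lambda)^l \delta_{t+l}$.
    done = np.zeros((n_workers, worker_steps), dtype=np.float32)

    advantages = GAE(n_workers, worker_steps, gamma, lambda_)(done, rewards, values)

    # Direct, non-recursive evaluation of the same sum for comparison
    deltas = rewards + gamma * values[:, 1:] - values[:, :-1]
    expected = np.zeros_like(deltas)
    for t in range(worker_steps):
        for lag in range(worker_steps - t):
            expected[:, t] += (gamma * lambda_) ** lag * deltas[:, t + lag]

    assert np.allclose(advantages, expected, atol=1e-4)
    print(advantages[0])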