Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 02:39:16 +08:00)
ppo
Varuna Jayasiri

labml_nn/rl/ppo/gae.py (new file, 72 additions)
@@ -0,0 +1,72 @@
import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
        """
        ### Calculate advantages

        `done` and `rewards` are collected per worker per step;
        `values[:, -1]` is the value estimate of the state reached after the last step.

        \begin{align}
        \hat{A_t^{(1)}} &= r_t + \gamma V(s_{t+1}) - V(s_t)
        \\
        \hat{A_t^{(2)}} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)
        \\
        ...
        \\
        \hat{A_t^{(\infty)}} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + ... - V(s_t)
        \end{align}

        $\hat{A_t^{(1)}}$ has high bias and low variance, whilst
        $\hat{A_t^{(\infty)}}$ is unbiased but has high variance.

        We take a weighted average of $\hat{A_t^{(k)}}$ to balance bias and variance.
        This is called Generalized Advantage Estimation.
        $$\hat{A_t} = \hat{A_t^{GAE}} = \sum_k w_k \hat{A_t^{(k)}}$$
        We set $w_k = \lambda^{k-1}$ (the $(1 - \lambda)$ normalization used in the GAE paper
        only rescales the advantages by a constant); this gives a clean recursive
        calculation for $\hat{A_t}$:

        \begin{align}
        \delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)
        \\
        \hat{A_t} &= \delta_t + \gamma \lambda \delta_{t+1} + ... +
                             (\gamma \lambda)^{T - t - 1} \delta_{T - 1}
        \\
        &= \delta_t + \gamma \lambda \hat{A_{t+1}}
        \end{align}
        """

        # advantages table
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0

        # $V(s_{t+1})$ for the final step, used to bootstrap the last delta
        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):
            # mask if episode completed after step $t$
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
            # $\delta_t$
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

            # $\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

            # Note that we are collecting the advantages in reverse order.
            # *My initial code appended them to a list and I forgot to
            #  reverse it later. It took me around 4 to 5 hours to find the bug.
            #  The model's performance was still improving slightly during the
            #  initial runs, probably because the samples are similar.*
            advantages[:, t] = last_advantage

            last_value = values[:, t]

        return advantages
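
A minimal sanity check (not part of this commit) for the recursion above: it compares the output of `GAE.__call__` against the explicit truncated sum $\hat{A_t} = \sum_l (\gamma \lambda)^l \delta_{t+l}$ on random data. The array shapes, and the assumption that `values` carries one extra column per worker holding the bootstrap value of the state after the last step, are inferred from the use of `values[:, -1]` above rather than stated in the commit.

import numpy as np

from labml_nn.rl.ppo.gae import GAE

n_workers, worker_steps = 4, 8
gamma, lambda_ = 0.99, 0.95
rng = np.random.default_rng(0)

# No episodes terminate, so the masking logic stays out of the way.
done = np.zeros((n_workers, worker_steps), dtype=np.float32)
rewards = rng.normal(size=(n_workers, worker_steps)).astype(np.float32)
# Assumed shape: one extra value per worker, values[:, t] = V(s_t) for t = 0 .. worker_steps
values = rng.normal(size=(n_workers, worker_steps + 1)).astype(np.float32)

gae = GAE(n_workers, worker_steps, gamma, lambda_)
advantages = gae(done, rewards, values)

# Explicit sum for worker 0, step 0:
# delta_l = r_l + gamma * V(s_{l+1}) - V(s_l)
deltas = rewards[0] + gamma * values[0, 1:] - values[0, :-1]
expected = sum((gamma * lambda_) ** l * deltas[l] for l in range(worker_steps))
assert np.isclose(advantages[0, 0], expected, atol=1e-4)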