From 50bd0556a569c02fc2010c08b4e76122e6bed2bb Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri

PPO Experiment with Atari Breakout

This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari
Breakout game with OpenAI Gym. It runs the game environments on multiple
processes to sample efficiently.

This patch makes the number of training epochs a dynamic hyper-parameter and
raises the default learning rate. Configurations marked with ⚙️ in the code
below can be changed while the experiment is running: the number of epochs to
train the model with sampled data, the value loss coefficient, the entropy
bonus coefficient, the clip range, and the learning rate.
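Dynamic hyper-parameters are read by calling the parameter object (for example
`self.epochs()` and `self.learning_rate()` in the trainer below), so a value
edited from the monitoring app is picked up on the next read. A rough,
hypothetical sketch of that callable pattern (a stand-in class for
illustration, not the labml implementation):

class DynamicValue:
    # Holds a value that can be swapped while training is running.
    def __init__(self, default):
        self._value = default

    def set(self, value):
        # In labml this update would come from the monitoring UI mid-run.
        self._value = value

    def __call__(self):
        # Read the current value, like `self.learning_rate()` in the trainer.
        return self._value

lr = DynamicValue(2.5e-4)
print(lr())   # 0.00025
lr.set(1e-3)  # changed while "training"
print(lr())   # 0.001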
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
-from labml.configs import FloatDynamicHyperParam
+from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE
Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
Model
class Model(Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        # A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)

        # A fully connected layer to get value function
        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
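A quick sanity check of the shapes this network produces; a minimal sketch
assuming the `Model` class and the imports above:

import torch

model = Model()
frames = torch.zeros(2, 4, 84, 84)   # a dummy batch of two stacked-frame observations
pi, value = model(frames)
print(pi.logits.shape)   # torch.Size([2, 4]) - one logit per Breakout action
print(value.shape)       # torch.Size([2])    - one value estimate per observation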
Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
Trainer
class Trainer:
    def __init__(self, *,
-                updates: int, epochs: int, n_workers: int, worker_steps: int, batches: int,
+                updates: int, epochs: IntDynamicHyperParam,
+                n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # number of updates
        self.updates = updates
        # number of epochs to train the model with sampled data
        self.epochs = epochs
        # number of worker processes
        self.n_workers = n_workers
        # number of steps to run on each process for a single update
        self.worker_steps = worker_steps
        # number of mini batches
        self.batches = batches

        # total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
        # size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)

        # Value loss coefficient
        self.value_loss_coef = value_loss_coef
        # Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef
        # Clipping range
        self.clip_range = clip_range
        # Learning rate
        self.learning_rate = learning_rate

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # model
        self.model = Model().to(device)

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

        # GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)

        # PPO Loss
        self.ppo_loss = ClippedPPOLoss()

        # Value Loss
        self.value_loss = ClippedValueFunctionLoss()
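With the defaults used in `main()` below (8 workers, 128 steps per worker, 4
mini batches), each update trains on 1024 samples split into mini batches of
256. A quick check of that arithmetic:

n_workers, worker_steps, batches = 8, 128, 4
batch_size = n_workers * worker_steps      # 1024 samples per update
mini_batch_size = batch_size // batches    # 256 samples per mini batch
assert batch_size % batches == 0
print(batch_size, mini_batch_size)         # 1024 256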
Sample data with current policy
    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=np.bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
            # sample `worker_steps` from each worker
            for t in range(self.worker_steps):
                obs[:, t] = self.obs
                # sample actions from $\pi_{\theta_{OLD}}$ for each of the `n_workers` workers
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

                # run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                # get results after executing the actions
                for w, worker in enumerate(self.workers):
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

                    # episode information; look at `Game` to see how it works
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

            # Get value of after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

        # calculate advantages
        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

        # flatten the `[workers, time_step]` samples for training
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
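The flattening at the end of `sample` merges the worker and time dimensions, so
every tensor ends up with one row per `(worker, step)` pair. A minimal sketch
of that reshape with the default sizes:

import numpy as np

n_workers, worker_steps = 8, 128
obs = np.zeros((n_workers, worker_steps, 4, 84, 84), dtype=np.uint8)
flat = obs.reshape(obs.shape[0] * obs.shape[1], *obs.shape[2:])
print(flat.shape)   # (1024, 4, 84, 84) - one row per (worker, step) sample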
Train the model based on samples
    def train(self, samples: Dict[str, torch.Tensor]):
-       for _ in range(self.epochs):
+       for _ in range(self.epochs()):
            # shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

            # for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
                # get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

                # train
                loss = self._calc_loss(mini_batch)

                # Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
                # Zero out the previously calculated gradients
                self.optimizer.zero_grad()
                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                # Update parameters based on gradients
                self.optimizer.step()
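The mini batches are formed by indexing every tensor in the sample dictionary
with the same slice of one random permutation. The same trick in isolation,
with toy sizes:

import torch

batch_size, mini_batch_size = 8, 4
samples = {'actions': torch.arange(batch_size), 'advantages': torch.randn(batch_size)}
indexes = torch.randperm(batch_size)
for start in range(0, batch_size, mini_batch_size):
    idx = indexes[start: start + mini_batch_size]
    mini_batch = {k: v[idx] for k, v in samples.items()}
    print(mini_batch['actions'])   # 4 non-overlapping, randomly ordered sample indices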
Normalize advantage function
    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)
Calculate total loss
    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
        # $R_t$ returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']

        sampled_normalized_advantage = self._normalize(samples['advantages'])

        pi, value = self.model(samples['obs'])

        # $\log \pi_\theta (a_t|s_t)$, $a_t$ are actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])

        # Calculate policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())

        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

        # Calculate value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

        # for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

        # Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss
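For reference, the scalar minimized above is the negative of the PPO objective
from the paper, with $c_1$ = `value_loss_coef` and $c_2$ = `entropy_bonus_coef`:

$$\mathcal{L}(\theta) = -\Big(\mathcal{L}^{CLIP}(\theta) - c_1\,\mathcal{L}^{VF}(\theta) + c_2\,\mathcal{S}[\pi_\theta]\Big)$$

Since `policy_loss` already equals $-\mathcal{L}^{CLIP}$ and `value_loss`
equals $\mathcal{L}^{VF}$, this is exactly
`policy_loss + value_loss_coef() * value_loss - entropy_bonus_coef() * entropy_bonus`.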
Run training loop
    def run_training_loop(self):
        # last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            # sample with current policy
            samples = self.sample()
            # train the model
            self.train(samples)
            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()
    def destroy(self):
        # Stop the workers
        for worker in self.workers:
            worker.child.send(("close", None))
def main():
    # Create the experiment
    experiment.create(name='ppo')
    # Configurations
    configs = {
        # Number of updates
        'updates': 10000,
        # ⚙️ Number of epochs to train the model with sampled data.
        # You can change this while the experiment is running.
-       'epochs': 4,
+       'epochs': IntDynamicHyperParam(8),
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 128,
        # Number of mini batches
        'batches': 4,
        # ⚙️ Value loss coefficient.
        # You can change this while the experiment is running.
        'value_loss_coef': FloatDynamicHyperParam(0.5),
        # ⚙️ Entropy bonus coefficient.
        # You can change this while the experiment is running.
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
        # ⚙️ Clip range.
        'clip_range': FloatDynamicHyperParam(0.1),
        # ⚙️ Learning rate.
        # You can change this while the experiment is running.
-       'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
+       'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

    # Initialize the trainer
    m = Trainer(
        updates=configs['updates'],
        epochs=configs['epochs'],
        n_workers=configs['n_workers'],
        worker_steps=configs['worker_steps'],
        batches=configs['batches'],
        value_loss_coef=configs['value_loss_coef'],
        entropy_bonus_coef=configs['entropy_bonus_coef'],
        clip_range=configs['clip_range'],
        learning_rate=configs['learning_rate'],
    )

    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()

    # Stop the workers
    m.destroy()
Run it
if __name__ == "__main__":
    main()
Generalized Advantage Estimation (GAE)
This is a PyTorch implementation of the paper Generalized Advantage Estimation.
You can find an experiment that uses it here.

import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
        # advantages table
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0

        # $V(s_{t+1})$
        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):
            # mask if episode completed after step $t$
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
            # $\delta_t$
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

            # $\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

            advantages[:, t] = last_advantage

            last_value = values[:, t]

        return advantages
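A tiny worked example of the recursion, assuming the `GAE` class above, with
one worker and three steps:

import numpy as np

gae = GAE(n_workers=1, worker_steps=3, gamma=0.99, lambda_=0.95)
done = np.array([[False, False, True]])                        # episode ends after the last step
rewards = np.array([[1.0, 0.0, 1.0]], dtype=np.float32)
values = np.array([[0.5, 0.4, 0.3, 0.0]], dtype=np.float32)    # $V(s_0) \dots V(s_3)$; $V(s_3)$ is masked out

adv = gae(done, rewards, values)
# t=2: delta = 1.0 + 0.99*0   - 0.3 = 0.7,    A_2 = 0.7
# t=1: delta = 0.0 + 0.99*0.3 - 0.4 = -0.103, A_1 = -0.103 + 0.9405*0.7   ≈ 0.555
# t=0: delta = 1.0 + 0.99*0.4 - 0.5 = 0.896,  A_0 = 0.896  + 0.9405*0.555 ≈ 1.418
print(adv)   # approximately [[1.418, 0.555, 0.7]]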
PPO Loss

You can find an experiment that uses it here. The experiment uses Generalized
Advantage Estimation.

import torch

from labml_helpers.module import Module
from labml_nn.rl.ppo.gae import GAE


class ClippedPPOLoss(Module):
    def __init__(self):
        super().__init__()

    def __call__(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                 advantage: torch.Tensor, clip: float) -> torch.Tensor:
        ratio = torch.exp(log_pi - sampled_log_pi)

        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        self.clip_fraction = (abs((ratio - 1.0)) > clip).to(torch.float).mean()

        return -policy_reward.mean()


class ClippedValueFunctionLoss(Module):
    def __call__(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
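A minimal check of the clipping behaviour, assuming the classes above. With
`clip=0.1`, ratios of 1.5 and 0.5 both fall outside $[0.9, 1.1]$:

import torch

ppo_loss = ClippedPPOLoss()
log_pi = torch.log(torch.tensor([0.30, 0.10]))          # log-probabilities under the current policy
sampled_log_pi = torch.log(torch.tensor([0.20, 0.20]))  # log-probabilities when the actions were sampled
advantage = torch.tensor([1.0, 1.0])

loss = ppo_loss(log_pi, sampled_log_pi, advantage, clip=0.1)
print(loss)                    # ≈ tensor(-0.8000): -(min(1.5, 1.1) + min(0.5, 0.9)) / 2
print(ppo_loss.clip_fraction)  # tensor(1.): both ratios are outside [0.9, 1.1]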