diff --git a/docs/rl/ppo/experiment.html b/docs/rl/ppo/experiment.html
index bf46a78b..a826c552 100644
--- a/docs/rl/ppo/experiment.html
+++ b/docs/rl/ppo/experiment.html
@@ -75,22 +75,24 @@
This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game with OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
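The Worker interface is not shown on this page, so here is a minimal, illustrative sketch (my own, inferred only from the calls the Trainer below makes) of the (command, payload) pipe protocol used to drive the game processes:

# Illustrative sketch only; assumes labml_nn.rl.game.Worker as imported below.
from labml_nn.rl.game import Worker

workers = [Worker(47 + i) for i in range(4)]    # each Worker runs one game in its own process
for w in workers:
    w.child.send(("reset", None))               # ask every game to reset
frames = [w.child.recv() for w in workers]      # each reply is a stacked 4x84x84 observation

for w in workers:
    w.child.send(("step", 0))                   # play action 0 in every game
results = [w.child.recv() for w in workers]     # each reply is (obs, reward, done, info)

for w in workers:
    w.child.send(("close", None))               # shut the worker processes down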
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_helpers.module import Module
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Model(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)

A fully connected layer to get value function

        self.value = nn.Linear(in_features=512, out_features=1)

        self.activation = nn.ReLU()

    def __call__(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
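As a quick sanity check (not part of the original file, just a sketch of the expected shapes), the model maps a batch of stacked frames to a categorical policy over Breakout's four actions and one value estimate per observation:

# Hypothetical shape check for the Model defined above.
dummy_obs = torch.zeros((2, 4, 84, 84), device=device)  # batch of 2 observations, 4 stacked 84x84 frames
pi, value = Model().to(device)(dummy_obs)
assert pi.logits.shape == (2, 4)  # logits over the 4 Breakout actions
assert value.shape == (2,)        # one value estimate per observation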
Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
number of updates

        self.updates = updates

number of epochs to train the model with sampled data

        self.epochs = epochs

number of worker processes

        self.n_workers = n_workers

number of steps to run on each process for a single update

        self.worker_steps = worker_steps

number of mini batches

        self.batches = batches

total number of samples for a single update

        self.batch_size = self.n_workers * self.worker_steps

size of a mini batch

        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
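For example, with the defaults set in main() below (8 workers, 128 steps per worker, 4 mini batches), each update collects 8 × 128 = 1024 samples and trains on mini batches of 1024 / 4 = 256.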
Value loss coefficient

        self.value_loss_coef = value_loss_coef

Entropy bonus coefficient

        self.entropy_bonus_coef = entropy_bonus_coef

Clipping range

        self.clip_range = clip_range

Learning rate

        self.learning_rate = learning_rate

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

model

        self.model = Model().to(device)

optimizer

        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

GAE with $\gamma = 0.99$ and $\lambda = 0.95$

        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)

PPO Loss

        self.ppo_loss = ClippedPPOLoss()

Value Loss

        self.value_loss = ClippedValueFunctionLoss()
    def sample(self) -> Dict[str, torch.Tensor]:
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=np.bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():

sample worker_steps from each worker

            for t in range(self.worker_steps):
                obs[:, t] = self.obs

sample actions from $\pi_{\theta_{OLD}}$ for each worker; this returns arrays of size n_workers

                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):

get results after executing the actions

                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

collect episode info, which is available when an episode finishes; look at Game to see how it works

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

Get value of after the final step

            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

calculate advantages

        advantages = self.gae(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat
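The reshape above flattens the [workers, time_steps] layout into a single batch dimension: with the default configuration (8 workers × 128 steps), samples_flat['obs'] is a (1024, 4, 84, 84) tensor, while 'actions', 'values', 'log_pis' and 'advantages' become vectors of length 1024.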
    def train(self, samples: Dict[str, torch.Tensor]):
        for _ in range(self.epochs()):

shuffle for each epoch

            indexes = torch.randperm(self.batch_size)

for each mini batch

            for start in range(0, self.batch_size, self.mini_batch_size):

get mini batch

                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

train

                loss = self._calc_loss(mini_batch)

Set learning rate

                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()

Zero out the previously calculated gradients

                self.optimizer.zero_grad()

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

Update parameters based on gradients

                self.optimizer.step()
    @staticmethod
    def _normalize(adv: torch.Tensor):
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

$R_t$ returns sampled from $\pi_{\theta_{OLD}}$

        sampled_return = samples['values'] + samples['advantages']

        sampled_normalized_advantage = self._normalize(samples['advantages'])

        pi, value = self.model(samples['obs'])

$-\log \pi_\theta (a_t|s_t)$, $a_t$ are actions sampled from $\pi_{\theta_{OLD}}$

        log_pi = pi.log_prob(samples['actions'])

Calculate policy loss

        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())

        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

Calculate value function loss

        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)
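Up to sign, this matches the combined PPO objective $L_t^{CLIP+VF+S}(\theta) = \hat{\mathbb{E}}_t \big[ L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t) \big]$ with $c_1$ = value_loss_coef and $c_2$ = entropy_bonus_coef; the sign is flipped because the optimizer minimizes a loss, and $L_t^{VF}$ here is the clipped value loss defined further below.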
for monitoring

        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
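For small policy changes, $\frac{1}{2}\mathbb{E}\big[(\log \pi_{\theta_{OLD}}(a_t|s_t) - \log \pi_\theta(a_t|s_t))^2\big]$ is a cheap approximation of the KL divergence between the old and new policies; it is only logged for monitoring and does not enter the loss.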
Add to tracker

        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss
    def run_training_loop(self):

last 100 episode information

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):

sample with current policy

            samples = self.sample()

train the model

            self.train(samples)

Save tracked indicators.

            tracker.save()

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

Stop the workers

    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
def main():

Create the experiment

    experiment.create(name='ppo')

Configurations

    configs = {

Number of updates

        'updates': 10000,

⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running.

        'epochs': IntDynamicHyperParam(8),

Number of worker processes

        'n_workers': 8,

Number of steps to run on each process for a single update

        'worker_steps': 128,

Number of mini batches

        'batches': 4,

⚙️ Value loss coefficient. You can change this while the experiment is running.

        'value_loss_coef': FloatDynamicHyperParam(0.5),

⚙️ Entropy bonus coefficient. You can change this while the experiment is running.

        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),

⚙️ Clip range.

        'clip_range': FloatDynamicHyperParam(0.1),

⚙️ Learning rate. You can change this while the experiment is running.

        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)
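The entries marked ⚙️ are dynamic hyper-parameters: the Trainer stores the IntDynamicHyperParam / FloatDynamicHyperParam objects themselves and reads their current value by calling them (self.epochs(), self.learning_rate(), self.clip_range(), self.value_loss_coef() and self.entropy_bonus_coef() above), which is what lets them be adjusted while the run is in progress (presumably from the labml monitoring app; this page only says they can be changed during the run).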
Initialize the trainer

    m = Trainer(
        updates=configs['updates'],
        epochs=configs['epochs'],
        n_workers=configs['n_workers'],
        worker_steps=configs['worker_steps'],
        batches=configs['batches'],
        value_loss_coef=configs['value_loss_coef'],
        entropy_bonus_coef=configs['entropy_bonus_coef'],
        clip_range=configs['clip_range'],
        learning_rate=configs['learning_rate'],
    )

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

Stop the workers

    m.destroy()


if __name__ == "__main__":
    main()
This is a PyTorch implementation of the paper Generalized Advantage Estimation. You can find an experiment that uses it here.

import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:

advantages table

        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0

$V(s_{t+1})$

        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):

mask if episode completed after step $t$

            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask

$\delta_t$

            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

$\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$

            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

            advantages[:, t] = last_advantage

            last_value = values[:, t]

        return advantages
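Unrolling this recursion gives the familiar sum from the GAE paper, $\hat{A}_t = \sum_{l \geq 0} (\gamma \lambda)^l \delta_{t+l}$ with $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$; the mask truncates the sum at episode boundaries so advantages never leak across finished episodes.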
You can find an experiment that uses it here. The experiment uses Generalized Advantage Estimation.

import torch

from labml_helpers.module import Module
from labml_nn.rl.ppo.gae import GAE


class ClippedPPOLoss(Module):
    def __init__(self):
        super().__init__()

    def __call__(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                 advantage: torch.Tensor, clip: float) -> torch.Tensor:
        ratio = torch.exp(log_pi - sampled_log_pi)
        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        self.clip_fraction = (abs((ratio - 1.0)) > clip).to(torch.float).mean()

        return -policy_reward.mean()
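For reference, this is the clipped surrogate objective of the PPO paper, $L^{CLIP}(\theta) = \mathbb{E}_t \big[ \min \big( r_t(\theta) \hat{A}_t,\ \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \big) \big]$ with ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)}$ and clip range $\epsilon$. The ratio is computed as $\exp(\log \pi_\theta - \log \pi_{\theta_{OLD}})$ for numerical stability, clip_fraction records how often clipping is active, and the negated mean is returned so the objective can be minimized.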
class ClippedValueFunctionLoss(Module):
    def __call__(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
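This is the clipped value function loss often paired with PPO: the new value estimate is kept within $\pm\epsilon$ of the value predicted while sampling, $V^{clip}(s_t) = V_{\theta_{OLD}}(s_t) + \mathrm{clip}\big(V_\theta(s_t) - V_{\theta_{OLD}}(s_t), -\epsilon, \epsilon\big)$, and the loss takes the larger (more pessimistic) of the clipped and unclipped squared errors, $\frac{1}{2}\mathbb{E}_t\big[\max\big((V_\theta(s_t) - R_t)^2,\ (V^{clip}(s_t) - R_t)^2\big)\big]$.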