mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	gan step
This commit is contained in:
		| @ -2,7 +2,7 @@ | |||||||
| # Generative Adversarial Networks experiment with MNIST | # Generative Adversarial Networks experiment with MNIST | ||||||
| """ | """ | ||||||
|  |  | ||||||
| from typing import Optional | from typing import Any | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| @ -16,7 +16,7 @@ from labml_helpers.datasets.mnist import MNISTConfigs | |||||||
| from labml_helpers.device import DeviceConfigs | from labml_helpers.device import DeviceConfigs | ||||||
| from labml_helpers.module import Module | from labml_helpers.module import Module | ||||||
| from labml_helpers.optimizer import OptimizerConfigs | from labml_helpers.optimizer import OptimizerConfigs | ||||||
| from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs | from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex | ||||||
| from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss | from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -79,89 +79,6 @@ class Discriminator(Module): | |||||||
|         return self.layers(x.view(x.shape[0], -1)) |         return self.layers(x.view(x.shape[0], -1)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class GANBatchStep(BatchStepProtocol): |  | ||||||
|     def __init__(self, *, |  | ||||||
|                  discriminator: Module, |  | ||||||
|                  generator: Module, |  | ||||||
|                  discriminator_optimizer: Optional[torch.optim.Adam], |  | ||||||
|                  generator_optimizer: Optional[torch.optim.Adam], |  | ||||||
|                  discriminator_loss: DiscriminatorLogitsLoss, |  | ||||||
|                  generator_loss: GeneratorLogitsLoss, |  | ||||||
|                  discriminator_k: int): |  | ||||||
|  |  | ||||||
|         self.discriminator_k = discriminator_k |  | ||||||
|         self.generator = generator |  | ||||||
|         self.discriminator = discriminator |  | ||||||
|         self.generator_loss = generator_loss |  | ||||||
|         self.discriminator_loss = discriminator_loss |  | ||||||
|         self.generator_optimizer = generator_optimizer |  | ||||||
|         self.discriminator_optimizer = discriminator_optimizer |  | ||||||
|  |  | ||||||
|         hook_model_outputs(self.generator, 'generator') |  | ||||||
|         hook_model_outputs(self.discriminator, 'discriminator') |  | ||||||
|         tracker.set_scalar("loss.generator.*", True) |  | ||||||
|         tracker.set_scalar("loss.discriminator.*", True) |  | ||||||
|         tracker.set_image("generated", True, 1 / 100) |  | ||||||
|  |  | ||||||
|     def prepare_for_iteration(self): |  | ||||||
|         if MODE_STATE.is_train: |  | ||||||
|             self.generator.train() |  | ||||||
|             self.discriminator.train() |  | ||||||
|         else: |  | ||||||
|             self.generator.eval() |  | ||||||
|             self.discriminator.eval() |  | ||||||
|  |  | ||||||
|     def process(self, batch: any, state: any): |  | ||||||
|         device = self.discriminator.device |  | ||||||
|         data, target = batch |  | ||||||
|         data, target = data.to(device), target.to(device) |  | ||||||
|  |  | ||||||
|         # Train the discriminator |  | ||||||
|         with monit.section("discriminator"): |  | ||||||
|             for _ in range(self.discriminator_k): |  | ||||||
|                 latent = torch.randn(data.shape[0], 100, device=device) |  | ||||||
|                 if MODE_STATE.is_train: |  | ||||||
|                     self.discriminator_optimizer.zero_grad() |  | ||||||
|                 logits_true = self.discriminator(data) |  | ||||||
|                 logits_false = self.discriminator(self.generator(latent).detach()) |  | ||||||
|                 loss_true, loss_false = self.discriminator_loss(logits_true, logits_false) |  | ||||||
|                 loss = loss_true + loss_false |  | ||||||
|  |  | ||||||
|                 # Log stuff |  | ||||||
|                 tracker.add("loss.discriminator.true.", loss_true) |  | ||||||
|                 tracker.add("loss.discriminator.false.", loss_false) |  | ||||||
|                 tracker.add("loss.discriminator.", loss) |  | ||||||
|  |  | ||||||
|                 # Train |  | ||||||
|                 if MODE_STATE.is_train: |  | ||||||
|                     loss.backward() |  | ||||||
|                     if MODE_STATE.is_log_parameters: |  | ||||||
|                         pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') |  | ||||||
|                     self.discriminator_optimizer.step() |  | ||||||
|  |  | ||||||
|         # Train the generator |  | ||||||
|         with monit.section("generator"): |  | ||||||
|             latent = torch.randn(data.shape[0], 100, device=device) |  | ||||||
|             if MODE_STATE.is_train: |  | ||||||
|                 self.generator_optimizer.zero_grad() |  | ||||||
|             generated_images = self.generator(latent) |  | ||||||
|             logits = self.discriminator(generated_images) |  | ||||||
|             loss = self.generator_loss(logits) |  | ||||||
|  |  | ||||||
|             # Log stuff |  | ||||||
|             tracker.add('generated', generated_images[0:5]) |  | ||||||
|             tracker.add("loss.generator.", loss) |  | ||||||
|  |  | ||||||
|             # Train |  | ||||||
|             if MODE_STATE.is_train: |  | ||||||
|                 loss.backward() |  | ||||||
|                 if MODE_STATE.is_log_parameters: |  | ||||||
|                     pytorch_utils.store_model_indicators(self.generator, 'generator') |  | ||||||
|                 self.generator_optimizer.step() |  | ||||||
|  |  | ||||||
|         return {'samples': len(data)}, None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Configs(MNISTConfigs, TrainValidConfigs): | class Configs(MNISTConfigs, TrainValidConfigs): | ||||||
|     device: torch.device = DeviceConfigs() |     device: torch.device = DeviceConfigs() | ||||||
|     epochs: int = 10 |     epochs: int = 10 | ||||||
| @ -173,10 +90,77 @@ class Configs(MNISTConfigs, TrainValidConfigs): | |||||||
|     discriminator_optimizer: torch.optim.Adam |     discriminator_optimizer: torch.optim.Adam | ||||||
|     generator_loss: GeneratorLogitsLoss |     generator_loss: GeneratorLogitsLoss | ||||||
|     discriminator_loss: DiscriminatorLogitsLoss |     discriminator_loss: DiscriminatorLogitsLoss | ||||||
|     batch_step = 'gan_batch_step' |  | ||||||
|     label_smoothing: float = 0.2 |     label_smoothing: float = 0.2 | ||||||
|     discriminator_k: int = 1 |     discriminator_k: int = 1 | ||||||
|  |  | ||||||
|  |     log_params_updates: int = 2 ** 32  # 0 if not | ||||||
|  |  | ||||||
|  |     def init(self): | ||||||
|  |         self.state_modules = [] | ||||||
|  |         self.generator = Generator().to(self.device) | ||||||
|  |         self.discriminator = Discriminator().to(self.device) | ||||||
|  |         self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device) | ||||||
|  |         self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device) | ||||||
|  |  | ||||||
|  |         hook_model_outputs(self.mode, self.generator, 'generator') | ||||||
|  |         hook_model_outputs(self.mode, self.discriminator, 'discriminator') | ||||||
|  |         tracker.set_scalar("loss.generator.*", True) | ||||||
|  |         tracker.set_scalar("loss.discriminator.*", True) | ||||||
|  |         tracker.set_image("generated", True, 1 / 100) | ||||||
|  |  | ||||||
|  |     def step(self, batch: Any, batch_idx: BatchIndex): | ||||||
|  |         self.generator.train(self.mode.is_train) | ||||||
|  |         self.discriminator.train(self.mode.is_train) | ||||||
|  |  | ||||||
|  |         data, target = batch[0].to(self.device), batch[1].to(self.device) | ||||||
|  |  | ||||||
|  |         # Increment step in training mode | ||||||
|  |         if self.mode.is_train: | ||||||
|  |             tracker.add_global_step(len(data)) | ||||||
|  |  | ||||||
|  |         # Train the discriminator | ||||||
|  |         with monit.section("discriminator"): | ||||||
|  |             for _ in range(self.discriminator_k): | ||||||
|  |                 latent = torch.randn(data.shape[0], 100, device=self.device) | ||||||
|  |                 logits_true = self.discriminator(data) | ||||||
|  |                 logits_false = self.discriminator(self.generator(latent).detach()) | ||||||
|  |                 loss_true, loss_false = self.discriminator_loss(logits_true, logits_false) | ||||||
|  |                 loss = loss_true + loss_false | ||||||
|  |  | ||||||
|  |                 # Log stuff | ||||||
|  |                 tracker.add("loss.discriminator.true.", loss_true) | ||||||
|  |                 tracker.add("loss.discriminator.false.", loss_false) | ||||||
|  |                 tracker.add("loss.discriminator.", loss) | ||||||
|  |  | ||||||
|  |                 # Train | ||||||
|  |                 if self.mode.is_train: | ||||||
|  |                     self.discriminator_optimizer.zero_grad() | ||||||
|  |                     loss.backward() | ||||||
|  |                     if batch_idx.is_interval(self.log_params_updates): | ||||||
|  |                         pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') | ||||||
|  |                     self.discriminator_optimizer.step() | ||||||
|  |  | ||||||
|  |         # Train the generator | ||||||
|  |         with monit.section("generator"): | ||||||
|  |             latent = torch.randn(data.shape[0], 100, device=self.device) | ||||||
|  |             generated_images = self.generator(latent) | ||||||
|  |             logits = self.discriminator(generated_images) | ||||||
|  |             loss = self.generator_loss(logits) | ||||||
|  |  | ||||||
|  |             # Log stuff | ||||||
|  |             tracker.add('generated', generated_images[0:5]) | ||||||
|  |             tracker.add("loss.generator.", loss) | ||||||
|  |  | ||||||
|  |             # Train | ||||||
|  |             if self.mode.is_train: | ||||||
|  |                 self.generator_optimizer.zero_grad() | ||||||
|  |                 loss.backward() | ||||||
|  |                 if batch_idx.is_interval(self.log_params_updates): | ||||||
|  |                     pytorch_utils.store_model_indicators(self.generator, 'generator') | ||||||
|  |                 self.generator_optimizer.step() | ||||||
|  |  | ||||||
|  |         tracker.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| @option(Configs.dataset_transforms) | @option(Configs.dataset_transforms) | ||||||
| def mnist_transforms(): | def mnist_transforms(): | ||||||
| @ -186,23 +170,6 @@ def mnist_transforms(): | |||||||
|     ]) |     ]) | ||||||
|  |  | ||||||
|  |  | ||||||
| @option(Configs.batch_step) |  | ||||||
| def gan_batch_step(c: Configs): |  | ||||||
|     return GANBatchStep(discriminator=c.discriminator, |  | ||||||
|                         generator=c.generator, |  | ||||||
|                         discriminator_optimizer=c.discriminator_optimizer, |  | ||||||
|                         generator_optimizer=c.generator_optimizer, |  | ||||||
|                         discriminator_loss=c.discriminator_loss, |  | ||||||
|                         generator_loss=c.generator_loss, |  | ||||||
|                         discriminator_k=c.discriminator_k) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device)) |  | ||||||
| calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device)) |  | ||||||
| calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device)) |  | ||||||
| calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @option(Configs.discriminator_optimizer) | @option(Configs.discriminator_optimizer) | ||||||
| def _discriminator_optimizer(c: Configs): | def _discriminator_optimizer(c: Configs): | ||||||
|     opt_conf = OptimizerConfigs() |     opt_conf = OptimizerConfigs() | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri