	🥳 simple gan
@@ -6,35 +6,41 @@ import torch.utils.data
 from labml_helpers.module import Module
 
 
+def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
+    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
+
+
 class DiscriminatorLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, smoothing: float = 0.2):
         super().__init__()
         self.loss_true = nn.BCEWithLogitsLoss()
         self.loss_false = nn.BCEWithLogitsLoss()
-        self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False)
-        self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False)
+        self.smoothing = smoothing
+        self.register_buffer('labels_true', create_labels(256, 1.0 - smoothing, 1.0), False)
+        self.register_buffer('labels_false', create_labels(256, 0.0, smoothing), False)
 
     def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         if len(logits_true) > len(self.labels_true):
             self.register_buffer("labels_true",
-                                 self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False)
+                                 create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
         if len(logits_false) > len(self.labels_false):
             self.register_buffer("labels_false",
-                                 self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False)
+                                 create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
 
         return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \
                self.loss_false(logits_false, self.labels_false[:len(logits_false)])
 
 
 class GeneratorLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, smoothing: float = 0.2):
         super().__init__()
         self.loss_true = nn.BCEWithLogitsLoss()
-        self.register_buffer('fake_labels', torch.ones(256, 1, requires_grad=False), False)
+        self.smoothing = smoothing
+        self.register_buffer('fake_labels', create_labels(256, 1.0 - smoothing, 1.0), False)
 
     def __call__(self, logits: torch.Tensor):
         if len(logits) > len(self.fake_labels):
             self.register_buffer("fake_labels",
-                                 create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
+                                 create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
 
         return self.loss_true(logits, self.fake_labels[:len(logits)])
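The main change in this hunk: instead of hard 0/1 targets, the losses now train against smoothed labels drawn uniformly from [1 - smoothing, 1] for real images and [0, smoothing] for fake ones. A minimal sketch of what the smoothed targets look like and how they feed nn.BCEWithLogitsLoss (batch size and logits below are illustrative):

    import torch
    import torch.nn as nn

    def create_labels(n, r1, r2, device=None):
        # Same idea as in the diff: targets sampled uniformly from [r1, r2]
        return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)

    smoothing = 0.2
    labels_true = create_labels(4, 1.0 - smoothing, 1.0)   # real targets in [0.8, 1.0]
    labels_false = create_labels(4, 0.0, smoothing)        # fake targets in [0.0, 0.2]

    bce = nn.BCEWithLogitsLoss()
    logits_real = torch.randn(4, 1)   # discriminator outputs on real images (made up)
    logits_fake = torch.randn(4, 1)   # discriminator outputs on generated images

    # The discriminator is pushed toward ~1 on real and ~0 on fake, but never
    # exactly, which softens its confidence and stabilizes adversarial training.
    d_loss = bce(logits_real, labels_true) + bce(logits_fake, labels_false)
    print(d_loss.item())

The label buffers are registered with persistent=False (the trailing False), so they travel with .to(device) but are not written into checkpoints, and they are regrown on demand whenever a batch is larger than the cached 256. The remaining hunks below are in the MNIST experiment script.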
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 import torch
 import torch.nn as nn
 import torch.utils.data
+from torchvision import transforms
 
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
@@ -72,7 +73,7 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_optimizer = discriminator_optimizer
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
-        tracker.set_image("generated", True)
+        tracker.set_image("generated", True, 1 / 100)
 
     def prepare_for_iteration(self):
         if MODE_STATE.is_train:
@@ -92,7 +93,7 @@ class GANBatchStep(BatchStepProtocol):
             if MODE_STATE.is_train:
                 self.generator_optimizer.zero_grad()
             generated_images = self.generator(latent)
-            # tracker.add('generated', generated_images[0:1])
+            tracker.add('generated', generated_images[0:5])
             logits = self.discriminator(generated_images)
             loss = self.generator_loss(logits)
             tracker.add("loss.generator.", loss)
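For context around this hunk: the generator's loss gives fake images real labels, so the generator is rewarded for pushing the discriminator's logits up, the standard non-saturating objective. A condensed, self-contained sketch of one generator update, with toy stand-in networks (the real Generator and Discriminator in this commit are not shown in the diff):

    import torch
    import torch.nn as nn

    # Toy stand-ins so the step below runs; shapes are illustrative only.
    generator = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
    discriminator = nn.Linear(784, 1)
    generator_loss = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(generator.parameters(), lr=2.5e-4)

    optimizer.zero_grad()
    latent = torch.randn(32, 100)
    generated_images = generator(latent)
    logits = discriminator(generated_images)
    # Fake images, *real* labels: the generator learns to fool the discriminator.
    loss = generator_loss(logits, torch.ones(32, 1))
    loss.backward()
    optimizer.step()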
@@ -127,9 +128,18 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     generator: Module
     generator_optimizer: torch.optim.Adam
     discriminator_optimizer: torch.optim.Adam
-    discriminator_loss = DiscriminatorLogitsLoss()
-    generator_loss = GeneratorLogitsLoss()
+    generator_loss: GeneratorLogitsLoss
+    discriminator_loss: DiscriminatorLogitsLoss
     batch_step = 'gan_batch_step'
+    label_smoothing: float = 0.2
+
+
+@option(Configs.dataset_transforms)
+def mnist_transforms():
+    return transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])
 
 
 @option(Configs.batch_step)
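The new mnist_transforms option rescales MNIST from [0, 1] to [-1, 1]; that range matches a generator whose final activation is tanh (an assumption about the Generator used here, which this diff does not show). A quick check of the value range:

    import torch
    from torchvision import transforms
    from PIL import Image

    tf = transforms.Compose([
        transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])

    img = Image.new('L', (28, 28), color=255)  # all-white stand-in for a digit
    x = tf(img)
    print(x.min().item(), x.max().item())      # 1.0 1.0 -- white maps to +1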
@@ -144,6 +154,8 @@ def gan_batch_step(c: Configs):
 
 calculate(Configs.generator, lambda c: Generator().to(c.device))
 calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
+calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
+calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
 
 
 @option(Configs.discriminator_optimizer)
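Constructing the losses through calculate, rather than as class-level defaults, does two things: smoothing now comes from the label_smoothing config, and .to(c.device) moves the registered label buffers along with the module. A small sketch of that buffer behaviour (the class name here is made up):

    import torch
    import torch.nn as nn

    class LossWithLabels(nn.Module):
        def __init__(self):
            super().__init__()
            # persistent=False: moved by .to(), but not saved in state_dict
            self.register_buffer('labels', torch.ones(4, 1), persistent=False)

    m = LossWithLabels()
    print(m.labels.device)       # cpu
    if torch.cuda.is_available():
        m = m.to('cuda')
        print(m.labels.device)   # cuda:0 -- the buffer moved with the module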
@@ -166,8 +178,12 @@ def main():
     experiment.configs(conf,
                        {'generator_optimizer.learning_rate': 2.5e-4,
                         'generator_optimizer.optimizer': 'Adam',
+                        'generator_optimizer.betas': (0.5, 0.999),
                         'discriminator_optimizer.learning_rate': 2.5e-4,
-                        'discriminator_optimizer.optimizer': 'Adam'},
+                        'discriminator_optimizer.optimizer': 'Adam',
+                        'discriminator_optimizer.betas': (0.5, 0.999),
+                        'label_smoothing': 0.01
+                        },
                        'run')
     with experiment.start():
         conf.run()
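The run now sets betas=(0.5, 0.999) for both optimizers, the convention popularized by DCGAN: lowering beta1 from the default 0.9 reduces momentum, so each network reacts faster to the other's moving target. It also overrides label_smoothing from its 0.2 class default down to 0.01 for this run. The equivalent plain-PyTorch optimizer construction (parameters below are placeholders):

    import torch

    params = [torch.nn.Parameter(torch.randn(8, 8))]  # placeholder weights
    # Default Adam uses betas=(0.9, 0.999); GAN training usually lowers beta1.
    opt = torch.optim.Adam(params, lr=2.5e-4, betas=(0.5, 0.999))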