mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 18:27:20 +08:00
comments
This commit is contained in:
@ -229,7 +229,8 @@ class ReplayBuffer:
|
||||
self.max_size = max_size
|
||||
self.data = []
|
||||
|
||||
def push_and_pop(self, data):
|
||||
def push_and_pop(self, data: torch.Tensor):
|
||||
"""Add/retrieve an image"""
|
||||
data = data.detach()
|
||||
res = []
|
||||
for element in data:
|
||||
@ -247,14 +248,17 @@ class ReplayBuffer:
|
||||
|
||||
|
||||
class Configs(BaseConfigs):
|
||||
"""This is the configurations for the experiment"""
|
||||
"""## Configurations"""
|
||||
|
||||
# `DeviceConfigs` will pick a GPU if available
|
||||
device: torch.device = DeviceConfigs()
|
||||
|
||||
# Hyper-parameters
|
||||
epochs: int = 200
|
||||
dataset_name: str = 'monet2photo'
|
||||
batch_size: int = 1
|
||||
|
||||
data_loader_workers = 8
|
||||
is_save_models = True
|
||||
|
||||
learning_rate = 0.0002
|
||||
adam_betas = (0.5, 0.999)
|
||||
@ -268,30 +272,35 @@ class Configs(BaseConfigs):
|
||||
cycle_loss = torch.nn.L1Loss()
|
||||
identity_loss = torch.nn.L1Loss()
|
||||
|
||||
batch_step = 'cycle_gan_batch_step'
|
||||
|
||||
# Image dimensions
|
||||
img_height = 256
|
||||
img_width = 256
|
||||
img_channels = 3
|
||||
|
||||
# Number of residual blocks in the generator
|
||||
n_residual_blocks = 9
|
||||
|
||||
# Loss coefficients
|
||||
cyclic_loss_coefficient = 10.0
|
||||
identity_loss_coefficient = 5.
|
||||
|
||||
sample_interval = 500
|
||||
|
||||
# Models
|
||||
generator_xy: GeneratorResNet
|
||||
generator_yx: GeneratorResNet
|
||||
discriminator_x: Discriminator
|
||||
discriminator_y: Discriminator
|
||||
|
||||
# Optimizers
|
||||
generator_optimizer: torch.optim.Adam
|
||||
discriminator_optimizer: torch.optim.Adam
|
||||
|
||||
# Learning rate schedules
|
||||
generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
|
||||
discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
|
||||
|
||||
# Data loaders
|
||||
dataloader: DataLoader
|
||||
valid_dataloader: DataLoader
|
||||
|
||||
@ -311,12 +320,16 @@ class Configs(BaseConfigs):
|
||||
gen_x = make_grid(gen_x, nrow=5, normalize=True)
|
||||
gen_y = make_grid(gen_y, nrow=5, normalize=True)
|
||||
|
||||
# arange images along y-axis
|
||||
# Arrange images along y-axis
|
||||
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
||||
|
||||
# Save grid
|
||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
## Training
|
||||
|
||||
We aim to solve:
|
||||
$$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$
|
||||
|
||||
@ -447,7 +460,7 @@ class Configs(BaseConfigs):
|
||||
|
||||
def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
|
||||
"""
|
||||
Optimize the generators with identity, gan and cycle losses.
|
||||
### Optimize the generators with identity, gan and cycle losses.
|
||||
"""
|
||||
|
||||
# Change to training mode
|
||||
@ -501,7 +514,7 @@ class Configs(BaseConfigs):
|
||||
gen_x: torch.Tensor, gen_y: torch.Tensor,
|
||||
true_labels: torch.Tensor, false_labels: torch.Tensor):
|
||||
"""
|
||||
Optimize the discriminators with gan loss.
|
||||
### Optimize the discriminators with gan loss.
|
||||
"""
|
||||
# GAN Loss
|
||||
# \begin{align}
|
||||
@ -529,7 +542,7 @@ class Configs(BaseConfigs):
|
||||
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
|
||||
def setup_models(self: Configs):
|
||||
"""
|
||||
This method setup the models
|
||||
## setup the models
|
||||
"""
|
||||
input_shape = (self.img_channels, self.img_height, self.img_width)
|
||||
|
||||
@ -560,7 +573,7 @@ def setup_models(self: Configs):
|
||||
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
|
||||
def setup_dataloader(self: Configs):
|
||||
"""
|
||||
This method setup the data loaders
|
||||
## setup the data loaders
|
||||
"""
|
||||
|
||||
# Location of the dataset
|
||||
|
Reference in New Issue
Block a user