mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-16 02:41:38 +08:00
comments
This commit is contained in:
@ -229,7 +229,8 @@ class ReplayBuffer:
|
|||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.data = []
|
self.data = []
|
||||||
|
|
||||||
def push_and_pop(self, data):
|
def push_and_pop(self, data: torch.Tensor):
|
||||||
|
"""Add/retrieve an image"""
|
||||||
data = data.detach()
|
data = data.detach()
|
||||||
res = []
|
res = []
|
||||||
for element in data:
|
for element in data:
|
||||||
@ -247,14 +248,17 @@ class ReplayBuffer:
|
|||||||
|
|
||||||
|
|
||||||
class Configs(BaseConfigs):
|
class Configs(BaseConfigs):
|
||||||
"""This is the configurations for the experiment"""
|
"""## Configurations"""
|
||||||
|
|
||||||
|
# `DeviceConfigs` will pick a GPU if available
|
||||||
device: torch.device = DeviceConfigs()
|
device: torch.device = DeviceConfigs()
|
||||||
|
|
||||||
|
# Hyper-parameters
|
||||||
epochs: int = 200
|
epochs: int = 200
|
||||||
dataset_name: str = 'monet2photo'
|
dataset_name: str = 'monet2photo'
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
|
|
||||||
data_loader_workers = 8
|
data_loader_workers = 8
|
||||||
is_save_models = True
|
|
||||||
|
|
||||||
learning_rate = 0.0002
|
learning_rate = 0.0002
|
||||||
adam_betas = (0.5, 0.999)
|
adam_betas = (0.5, 0.999)
|
||||||
@ -268,30 +272,35 @@ class Configs(BaseConfigs):
|
|||||||
cycle_loss = torch.nn.L1Loss()
|
cycle_loss = torch.nn.L1Loss()
|
||||||
identity_loss = torch.nn.L1Loss()
|
identity_loss = torch.nn.L1Loss()
|
||||||
|
|
||||||
batch_step = 'cycle_gan_batch_step'
|
# Image dimensions
|
||||||
|
|
||||||
img_height = 256
|
img_height = 256
|
||||||
img_width = 256
|
img_width = 256
|
||||||
img_channels = 3
|
img_channels = 3
|
||||||
|
|
||||||
|
# Number of residual blocks in the generator
|
||||||
n_residual_blocks = 9
|
n_residual_blocks = 9
|
||||||
|
|
||||||
|
# Loss coefficients
|
||||||
cyclic_loss_coefficient = 10.0
|
cyclic_loss_coefficient = 10.0
|
||||||
identity_loss_coefficient = 5.
|
identity_loss_coefficient = 5.
|
||||||
|
|
||||||
sample_interval = 500
|
sample_interval = 500
|
||||||
|
|
||||||
|
# Models
|
||||||
generator_xy: GeneratorResNet
|
generator_xy: GeneratorResNet
|
||||||
generator_yx: GeneratorResNet
|
generator_yx: GeneratorResNet
|
||||||
discriminator_x: Discriminator
|
discriminator_x: Discriminator
|
||||||
discriminator_y: Discriminator
|
discriminator_y: Discriminator
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
generator_optimizer: torch.optim.Adam
|
generator_optimizer: torch.optim.Adam
|
||||||
discriminator_optimizer: torch.optim.Adam
|
discriminator_optimizer: torch.optim.Adam
|
||||||
|
|
||||||
|
# Learning rate schedules
|
||||||
generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
|
generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
|
||||||
discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
|
discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
|
||||||
|
|
||||||
|
# Data loaders
|
||||||
dataloader: DataLoader
|
dataloader: DataLoader
|
||||||
valid_dataloader: DataLoader
|
valid_dataloader: DataLoader
|
||||||
|
|
||||||
@ -311,12 +320,16 @@ class Configs(BaseConfigs):
|
|||||||
gen_x = make_grid(gen_x, nrow=5, normalize=True)
|
gen_x = make_grid(gen_x, nrow=5, normalize=True)
|
||||||
gen_y = make_grid(gen_y, nrow=5, normalize=True)
|
gen_y = make_grid(gen_y, nrow=5, normalize=True)
|
||||||
|
|
||||||
# arange images along y-axis
|
# Arrange images along y-axis
|
||||||
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
||||||
|
|
||||||
|
# Save grid
|
||||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
|
## Training
|
||||||
|
|
||||||
We aim to solve:
|
We aim to solve:
|
||||||
$$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$
|
$$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$
|
||||||
|
|
||||||
@ -447,7 +460,7 @@ class Configs(BaseConfigs):
|
|||||||
|
|
||||||
def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
|
def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Optimize the generators with identity, gan and cycle losses.
|
### Optimize the generators with identity, gan and cycle losses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Change to training mode
|
# Change to training mode
|
||||||
@ -501,7 +514,7 @@ class Configs(BaseConfigs):
|
|||||||
gen_x: torch.Tensor, gen_y: torch.Tensor,
|
gen_x: torch.Tensor, gen_y: torch.Tensor,
|
||||||
true_labels: torch.Tensor, false_labels: torch.Tensor):
|
true_labels: torch.Tensor, false_labels: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Optimize the discriminators with gan loss.
|
### Optimize the discriminators with gan loss.
|
||||||
"""
|
"""
|
||||||
# GAN Loss
|
# GAN Loss
|
||||||
# \begin{align}
|
# \begin{align}
|
||||||
@ -529,7 +542,7 @@ class Configs(BaseConfigs):
|
|||||||
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
|
Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
|
||||||
def setup_models(self: Configs):
|
def setup_models(self: Configs):
|
||||||
"""
|
"""
|
||||||
This method setup the models
|
## setup the models
|
||||||
"""
|
"""
|
||||||
input_shape = (self.img_channels, self.img_height, self.img_width)
|
input_shape = (self.img_channels, self.img_height, self.img_width)
|
||||||
|
|
||||||
@ -560,7 +573,7 @@ def setup_models(self: Configs):
|
|||||||
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
|
@configs.setup([Configs.dataloader, Configs.valid_dataloader])
|
||||||
def setup_dataloader(self: Configs):
|
def setup_dataloader(self: Configs):
|
||||||
"""
|
"""
|
||||||
This method setup the data loaders
|
## setup the data loaders
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Location of the dataset
|
# Location of the dataset
|
||||||
|
Reference in New Issue
Block a user