🧹 generator input shape
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
@@ -39,11 +39,8 @@ class GeneratorResNet(Module):
     The generator is a residual network.
     """
 
-    def __init__(self, input_shape: Tuple[int, int, int], n_residual_blocks: int):
+    def __init__(self, input_channels: int, n_residual_blocks: int):
         super().__init__()
-        # The number of channels in the input image, which is 3 for RGB images.
-        channels = input_shape[0]
-
         # This first block runs a $7\times7$ convolution and maps the image to
         # a feature map.
         # The output feature map has same height and width because we have
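Dropping the height and width from the constructor works because the generator is fully convolutional: only the channel count determines the layer weights. A minimal sketch of the idea, using the first block from this diff (the test resolutions are arbitrary):

import torch
import torch.nn as nn

# The first block of the generator, built from the channel count alone.
first_block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, padding=3, padding_mode='reflect'),
    nn.InstanceNorm2d(64),
    nn.ReLU(inplace=True),
)

# A fully convolutional module runs at any resolution,
# so passing the full `input_shape` was redundant.
for height, width in [(128, 128), (256, 192)]:
    x = torch.randn(1, 3, height, width)
    y = first_block(x)
    assert y.shape == (1, 64, height, width)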
@@ -53,7 +50,7 @@ class GeneratorResNet(Module):
         # `inplace=True` in `ReLU` saves a little bit of memory.
         out_features = 64
         layers = [
-            nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
+            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
             nn.InstanceNorm2d(out_features),
             nn.ReLU(inplace=True),
         ]
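The "same height and width" comment checks out: with a $7\times7$ kernel, stride $1$ and padding $3$, the output size is $H_{out} = \lfloor (H + 2 \cdot 3 - 7) / 1 \rfloor + 1 = H$, so the feature map keeps the spatial dimensions of the input image. Reflection padding avoids the border artifacts that zero padding would introduce at image edges.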
@@ -88,7 +85,7 @@ class GeneratorResNet(Module):
             in_features = out_features
 
         # Finally we map the feature map to an RGB image
-        layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
+        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
 
         # Create a sequential module with the layers
         self.layers = nn.Sequential(*layers)
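The final $\tanh$ squashes the generated image into $[-1, 1]$, matching the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform in the sampling code below, which maps inputs into the same range. A minimal sketch, assuming that convention, of undoing the normalization for display (the names are illustrative):

import torch

def denormalize(generated: torch.Tensor) -> torch.Tensor:
    # Map the generator's tanh output from [-1, 1] back to [0, 1] for display.
    return 0.5 * (generated + 1.0)

fake = torch.tanh(torch.randn(1, 3, 64, 64))  # stand-in for a generator output
img = denormalize(fake)
assert img.min() >= 0.0 and img.max() <= 1.0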
@@ -555,8 +552,8 @@ def setup_models(self: Configs):
     input_shape = (self.img_channels, self.img_height, self.img_width)
 
     # Create the models
-    self.generator_xy = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
-    self.generator_yx = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
+    self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
+    self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
     self.discriminator_x = Discriminator(input_shape).to(self.device)
     self.discriminator_y = Discriminator(input_shape).to(self.device)
 
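Note the asymmetry: the generators now take only `self.img_channels`, while the discriminators keep the full `input_shape`. A plausible reason, stated as an assumption about this particular `Discriminator`, is that a PatchGAN-style discriminator downsamples by a fixed factor and needs the height and width to compute its output patch grid. A minimal sketch of that calculation (the $2^4$ downsampling factor is an assumption, not taken from this diff):

from typing import Tuple

def patch_output_shape(input_shape: Tuple[int, int, int]) -> Tuple[int, int, int]:
    # A PatchGAN discriminator that halves the resolution four times
    # produces one real/fake score per (16 x 16) input patch.
    channels, height, width = input_shape
    return (1, height // 2 ** 4, width // 2 ** 4)

assert patch_output_shape((3, 256, 256)) == (1, 16, 16)

If the discriminator were fully convolutional end to end with no shape-dependent bookkeeping, the same cleanup could be applied there as well.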
@@ -680,9 +677,6 @@ def sample():
 
     # Image transformations
    transforms_ = [
-        transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC),
-        transforms.RandomCrop((conf.img_height, conf.img_width)),
-        transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     ]
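The sampling pipeline keeps only the deterministic steps; the removed resize, random crop, and horizontal flip are training-time augmentations that would only perturb evaluation samples. A minimal usage sketch of the remaining transforms (the file name is a placeholder):

from PIL import Image
import torchvision.transforms as transforms

transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Compose and apply to an input image; the path below is hypothetical.
pipeline = transforms.Compose(transforms_)
image = Image.open('example.jpg').convert('RGB')
x = pipeline(image).unsqueeze(0)  # add a batch dimension: (1, 3, H, W)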