diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index dd80c309..dff94354 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -39,11 +39,8 @@ class GeneratorResNet(Module): The generator is a residual network. """ - def __init__(self, input_shape: Tuple[int, int, int], n_residual_blocks: int): + def __init__(self, input_channels: int, n_residual_blocks: int): super().__init__() - # The number of channels in the input image, which is 3 for RGB images. - channels = input_shape[0] - # This first block runs a $7\times7$ convolution and maps the image to # a feature map. # The output feature map has same height and width because we have @@ -53,7 +50,7 @@ class GeneratorResNet(Module): # `inplace=True` in `ReLU` saves a little bit of memory. out_features = 64 layers = [ - nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'), + nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True), ] @@ -88,7 +85,7 @@ class GeneratorResNet(Module): in_features = out_features # Finally we map the feature map to an RGB image - layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()] + layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()] # Create a sequential module with the layers self.layers = nn.Sequential(*layers) @@ -555,8 +552,8 @@ def setup_models(self: Configs): input_shape = (self.img_channels, self.img_height, self.img_width) # Create the models - self.generator_xy = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device) - self.generator_yx = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device) + self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) + self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) self.discriminator_x = Discriminator(input_shape).to(self.device) self.discriminator_y = Discriminator(input_shape).to(self.device) @@ -680,9 +677,6 @@ def sample(): # Image transformations transforms_ = [ - transforms.Resize(int(conf.img_height * 1.12), Image.BICUBIC), - transforms.RandomCrop((conf.img_height, conf.img_width)), - transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]