diff --git a/labml_nn/diffusion/ddpm/unet.py b/labml_nn/diffusion/ddpm/unet.py index e44a39de..4a2a0572 100644 --- a/labml_nn/diffusion/ddpm/unet.py +++ b/labml_nn/diffusion/ddpm/unet.py @@ -26,7 +26,6 @@ from typing import Optional, Tuple, Union, List import torch from torch import nn -import torch.nn.functional as F from labml_helpers.module import Module @@ -92,13 +91,14 @@ class ResidualBlock(Module): Each resolution is processed with two residual blocks. """ - def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1): + def __init__(self, in_channels: int, out_channels: int, time_channels: int, + n_groups: int = 32, dropout: float = 0.1): """ * `in_channels` is the number of input channels * `out_channels` is the number of input channels * `time_channels` is the number channels in the time step ($t$) embeddings * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html) - * `dropout_rate` is the dropout rate + * `dropout` is the dropout rate """ super().__init__() # Group normalization and the first convolution layer @@ -122,6 +122,8 @@ class ResidualBlock(Module): self.time_emb = nn.Linear(time_channels, out_channels) self.time_act = Swish() + self.dropout = nn.Dropout(dropout) + def forward(self, x: torch.Tensor, t: torch.Tensor): """ * `x` has shape `[batch_size, in_channels, height, width]` @@ -132,7 +134,7 @@ class ResidualBlock(Module): # Add time embeddings h += self.time_emb(self.time_act(t))[:, :, None, None] # Second convolution layer - h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate)) + h = self.conv2(self.dropout(self.act2(self.norm2(h)))) # Add the shortcut connection and return return h + self.shortcut(x)