fix dropout ddpm.unet

This commit is contained in:
Varuna Jayasiri
2023-02-17 14:28:24 +05:30
parent 4db39b578c
commit d198a44fa2

View File

@ -26,7 +26,6 @@ from typing import Optional, Tuple, Union, List
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F
from labml_helpers.module import Module from labml_helpers.module import Module
@ -92,13 +91,14 @@ class ResidualBlock(Module):
Each resolution is processed with two residual blocks. Each resolution is processed with two residual blocks.
""" """
def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1): def __init__(self, in_channels: int, out_channels: int, time_channels: int,
n_groups: int = 32, dropout: float = 0.1):
""" """
* `in_channels` is the number of input channels * `in_channels` is the number of input channels
* `out_channels` is the number of output channels * `out_channels` is the number of output channels
* `time_channels` is the number of channels in the time step ($t$) embeddings * `time_channels` is the number of channels in the time step ($t$) embeddings
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html) * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
* `dropout_rate` is the dropout rate * `dropout` is the dropout rate
""" """
super().__init__() super().__init__()
# Group normalization and the first convolution layer # Group normalization and the first convolution layer
@ -122,6 +122,8 @@ class ResidualBlock(Module):
self.time_emb = nn.Linear(time_channels, out_channels) self.time_emb = nn.Linear(time_channels, out_channels)
self.time_act = Swish() self.time_act = Swish()
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, t: torch.Tensor): def forward(self, x: torch.Tensor, t: torch.Tensor):
""" """
* `x` has shape `[batch_size, in_channels, height, width]` * `x` has shape `[batch_size, in_channels, height, width]`
@ -132,7 +134,7 @@ class ResidualBlock(Module):
# Add time embeddings # Add time embeddings
h += self.time_emb(self.time_act(t))[:, :, None, None] h += self.time_emb(self.time_act(t))[:, :, None, None]
# Second convolution layer # Second convolution layer
h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate)) h = self.conv2(self.dropout(self.act2(self.norm2(h))))
# Add the shortcut connection and return # Add the shortcut connection and return
return h + self.shortcut(x) return h + self.shortcut(x)