mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
fix dropout ddpm.unet
This commit is contained in:
@ -26,7 +26,6 @@ from typing import Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from labml_helpers.module import Module
|
||||
|
||||
@ -92,13 +91,14 @@ class ResidualBlock(Module):
|
||||
Each resolution is processed with two residual blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1):
|
||||
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
|
||||
n_groups: int = 32, dropout: float = 0.1):
|
||||
"""
|
||||
* `in_channels` is the number of input channels
|
||||
* `out_channels` is the number of input channels
|
||||
* `time_channels` is the number channels in the time step ($t$) embeddings
|
||||
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
|
||||
* `dropout_rate` is the dropout rate
|
||||
* `dropout` is the dropout rate
|
||||
"""
|
||||
super().__init__()
|
||||
# Group normalization and the first convolution layer
|
||||
@ -122,6 +122,8 @@ class ResidualBlock(Module):
|
||||
self.time_emb = nn.Linear(time_channels, out_channels)
|
||||
self.time_act = Swish()
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor):
|
||||
"""
|
||||
* `x` has shape `[batch_size, in_channels, height, width]`
|
||||
@ -132,7 +134,7 @@ class ResidualBlock(Module):
|
||||
# Add time embeddings
|
||||
h += self.time_emb(self.time_act(t))[:, :, None, None]
|
||||
# Second convolution layer
|
||||
h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate))
|
||||
h = self.conv2(self.dropout(self.act2(self.norm2(h))))
|
||||
|
||||
# Add the shortcut connection and return
|
||||
return h + self.shortcut(x)
|
||||
|
Reference in New Issue
Block a user