Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-08-14 17:41:37 +08:00
Merge pull request #168 from jakehsiao/patch-3
Add activation for time embedding and dropout for Residual block in DDPM UNet
@@ -26,6 +26,7 @@ from typing import Optional, Tuple, Union, List
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 from labml_helpers.module import Module
 
@@ -91,12 +92,13 @@ class ResidualBlock(Module):
     Each resolution is processed with two residual blocks.
     """
 
-    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
+    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1):
         """
         * `in_channels` is the number of input channels
         * `out_channels` is the number of input channels
         * `time_channels` is the number channels in the time step ($t$) embeddings
         * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
+        * `dropout_rate` is the dropout rate
         """
         super().__init__()
         # Group normalization and the first convolution layer
@@ -118,6 +120,7 @@ class ResidualBlock(Module):
 
         # Linear layer for time embeddings
         self.time_emb = nn.Linear(time_channels, out_channels)
+        self.time_act = Swish()
 
     def forward(self, x: torch.Tensor, t: torch.Tensor):
         """
@@ -127,9 +130,9 @@ class ResidualBlock(Module):
         # First convolution layer
         h = self.conv1(self.act1(self.norm1(x)))
         # Add time embeddings
-        h += self.time_emb(t)[:, :, None, None]
+        h += self.time_emb(self.time_act(t))[:, :, None, None]
         # Second convolution layer
-        h = self.conv2(self.act2(self.norm2(h)))
+        h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate))
 
         # Add the shortcut connection and return
         return h + self.shortcut(x)
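
For reference, below is a minimal, self-contained sketch of what `ResidualBlock` looks like after this commit. It is an illustration under assumptions, not the repository's exact code: the `Swish` module and the group-norm, convolution, and shortcut layers are reproduced here from the surrounding DDPM UNet file; a `self.dropout_rate = dropout_rate` assignment is added so the snippet runs on its own; and `training=self.training` is passed to `F.dropout` so dropout is disabled in eval mode (the diff as merged calls `F.dropout` with its default `training=True`).

import torch
import torch.nn.functional as F
from torch import nn


class Swish(nn.Module):
    # x * sigmoid(x), as defined elsewhere in the UNet file
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout_rate: float = 0.1):
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 projection on the shortcut when the channel count changes
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1)
                         if in_channels != out_channels else nn.Identity())
        # Linear layer for time embeddings, now preceded by an activation (this commit)
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()
        # Stored so forward() can use it; assumed here, not shown in the mirrored diff
        self.dropout_rate = dropout_rate

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings; the activation is applied to t before the projection (this commit)
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Second convolution layer with dropout before the convolution (this commit);
        # training=self.training turns dropout off when the module is in eval mode
        h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate,
                                 training=self.training))
        # Add the shortcut connection and return
        return h + self.shortcut(x)


if __name__ == "__main__":
    # Quick shape check with hypothetical sizes
    block = ResidualBlock(64, 128, time_channels=256)
    x = torch.randn(2, 64, 32, 32)
    t = torch.randn(2, 256)
    print(block(x, t).shape)  # torch.Size([2, 128, 32, 32])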