diff --git a/labml_nn/diffusion/stable_diffusion/model/unet.py b/labml_nn/diffusion/stable_diffusion/model/unet.py index ccf08140..311482dd 100644 --- a/labml_nn/diffusion/stable_diffusion/model/unet.py +++ b/labml_nn/diffusion/stable_diffusion/model/unet.py @@ -20,7 +20,6 @@ from typing import List import numpy as np import torch -import torch as th import torch.nn as nn import torch.nn.functional as F @@ -174,7 +173,7 @@ class UNetModel(nn.Module): x = self.middle_block(x, t_emb, cond) # Output half of the U-Net for module in self.output_blocks: - x = th.cat([x, x_input_block.pop()], dim=1) + x = torch.cat([x, x_input_block.pop()], dim=1) x = module(x, t_emb, cond) # Final normalization and $3 \times 3$ convolution