delete unnecessary code and replace

This commit is contained in:
cjfghk5697
2023-01-19 23:40:23 +09:00
parent 59dde18a94
commit ce816b9be3

View File

@ -20,7 +20,6 @@ from typing import List
import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
@ -174,7 +173,7 @@ class UNetModel(nn.Module):
x = self.middle_block(x, t_emb, cond)
# Output half of the U-Net
for module in self.output_blocks:
x = th.cat([x, x_input_block.pop()], dim=1)
x = torch.cat([x, x_input_block.pop()], dim=1)
x = module(x, t_emb, cond)
# Final normalization and $3 \times 3$ convolution