delete unnecessary code and replace

This commit is contained in:
cjfghk5697
2023-01-19 23:40:23 +09:00
parent 59dde18a94
commit ce816b9be3

View File

@ -20,7 +20,6 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -174,7 +173,7 @@ class UNetModel(nn.Module):
x = self.middle_block(x, t_emb, cond) x = self.middle_block(x, t_emb, cond)
# Output half of the U-Net # Output half of the U-Net
for module in self.output_blocks: for module in self.output_blocks:
x = th.cat([x, x_input_block.pop()], dim=1) x = torch.cat([x, x_input_block.pop()], dim=1)
x = module(x, t_emb, cond) x = module(x, t_emb, cond)
# Final normalization and $3 \times 3$ convolution # Final normalization and $3 \times 3$ convolution