mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-16 02:41:38 +08:00
delete unnecessary code and replace
This commit is contained in:
@ -20,7 +20,6 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch as th
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
@ -174,7 +173,7 @@ class UNetModel(nn.Module):
|
|||||||
x = self.middle_block(x, t_emb, cond)
|
x = self.middle_block(x, t_emb, cond)
|
||||||
# Output half of the U-Net
|
# Output half of the U-Net
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
x = th.cat([x, x_input_block.pop()], dim=1)
|
x = torch.cat([x, x_input_block.pop()], dim=1)
|
||||||
x = module(x, t_emb, cond)
|
x = module(x, t_emb, cond)
|
||||||
|
|
||||||
# Final normalization and $3 \times 3$ convolution
|
# Final normalization and $3 \times 3$ convolution
|
||||||
|
Reference in New Issue
Block a user