diff --git a/labml_nn/diffusion/stable_diffusion/model/unet.py b/labml_nn/diffusion/stable_diffusion/model/unet.py index 311482dd..761bded9 100644 --- a/labml_nn/diffusion/stable_diffusion/model/unet.py +++ b/labml_nn/diffusion/stable_diffusion/model/unet.py @@ -49,7 +49,9 @@ class UNetModel(nn.Module): :param n_res_blocks: number of residual blocks at each level :param attention_levels: are the levels at which attention should be performed :param channel_multipliers: are the multiplicative factors for number of channels for each level - :param n_heads: the number of attention heads in the transformers + :param n_heads: is the number of attention heads in the transformers + :param tf_layers: is the number of transformer layers in the transformers + :param d_cond: is the size of the conditional embedding in the transformers """ super().__init__() self.channels = channels