diff --git a/tutorials/03-advanced/variational_auto_encoder/main.py b/tutorials/03-advanced/variational_auto_encoder/main.py index 17eb96a..d48214a 100644 --- a/tutorials/03-advanced/variational_auto_encoder/main.py +++ b/tutorials/03-advanced/variational_auto_encoder/main.py @@ -37,7 +37,7 @@ class VAE(nn.Module): nn.Linear(h_dim, image_size), nn.Sigmoid()) - def reparametrize(self, mu, log_var): + def reparameterize(self, mu, log_var): """"z = mean + eps * sigma where eps is sampled from N(0, 1).""" eps = to_var(torch.randn(mu.size(0), mu.size(1))) z = mu + eps * torch.exp(log_var/2) # 2 for convert var to std @@ -46,7 +46,7 @@ class VAE(nn.Module): def forward(self, x): h = self.encoder(x) mu, log_var = torch.chunk(h, 2, dim=1) # mean and log variance. - z = self.reparametrize(mu, log_var) + z = self.reparameterize(mu, log_var) out = self.decoder(z) return out, mu, log_var