From dd297a130785320e5a82d5a05f10c58ff39956a8 Mon Sep 17 00:00:00 2001 From: qy-yang Date: Tue, 27 Aug 2019 17:31:42 +0800 Subject: [PATCH] fix #187 Typos in language model line 79 and generative_adversarial_network line 25-28 --- tutorials/02-intermediate/language_model/main.py | 2 +- .../03-advanced/generative_adversarial_network/main.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tutorials/02-intermediate/language_model/main.py b/tutorials/02-intermediate/language_model/main.py index 3c03db0..ef135bb 100644 --- a/tutorials/02-intermediate/language_model/main.py +++ b/tutorials/02-intermediate/language_model/main.py @@ -76,7 +76,7 @@ for epoch in range(num_epochs): loss = criterion(outputs, targets.reshape(-1)) # Backward and optimize - model.zero_grad() + optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 0.5) optimizer.step() diff --git a/tutorials/03-advanced/generative_adversarial_network/main.py b/tutorials/03-advanced/generative_adversarial_network/main.py index 34f4127..c2062cf 100644 --- a/tutorials/03-advanced/generative_adversarial_network/main.py +++ b/tutorials/03-advanced/generative_adversarial_network/main.py @@ -22,10 +22,14 @@ if not os.path.exists(sample_dir): os.makedirs(sample_dir) # Image processing +# transform = transforms.Compose([ +# transforms.ToTensor(), +# transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels +# std=(0.5, 0.5, 0.5))]) transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels - std=(0.5, 0.5, 0.5))]) + transforms.Normalize(mean=[0.5], # 1 for greyscale channels + std=[0.5])]) # MNIST dataset mnist = torchvision.datasets.MNIST(root='../../data/',