mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-06 01:15:59 +08:00
fix #187 Typos in language model line 79 and generative_adversarial_network line 25-28
This commit is contained in:
@ -76,7 +76,7 @@ for epoch in range(num_epochs):
|
|||||||
loss = criterion(outputs, targets.reshape(-1))
|
loss = criterion(outputs, targets.reshape(-1))
|
||||||
|
|
||||||
# Backward and optimize
|
# Backward and optimize
|
||||||
model.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
clip_grad_norm_(model.parameters(), 0.5)
|
clip_grad_norm_(model.parameters(), 0.5)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
@ -22,10 +22,14 @@ if not os.path.exists(sample_dir):
|
|||||||
os.makedirs(sample_dir)
|
os.makedirs(sample_dir)
|
||||||
|
|
||||||
# Image processing
|
# Image processing
|
||||||
|
# transform = transforms.Compose([
|
||||||
|
# transforms.ToTensor(),
|
||||||
|
# transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
|
||||||
|
# std=(0.5, 0.5, 0.5))])
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
|
transforms.Normalize(mean=[0.5], # 1 for greyscale channels
|
||||||
std=(0.5, 0.5, 0.5))])
|
std=[0.5])])
|
||||||
|
|
||||||
# MNIST dataset
|
# MNIST dataset
|
||||||
mnist = torchvision.datasets.MNIST(root='../../data/',
|
mnist = torchvision.datasets.MNIST(root='../../data/',
|
||||||
|
Reference in New Issue
Block a user