Mirror of https://github.com/yunjey/pytorch-tutorial.git, synced 2025-07-05 08:26:16 +08:00.
@@ -61,7 +61,7 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
 # Truncated Backpropagation
 def detach(states):
-    return [Variable(state.data) for state in states]
+    return [state.detach() for state in states]
 
 # Training
 for epoch in range(num_epochs):
@@ -119,4 +119,4 @@ with open(sample_path, 'w') as f:
             print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
 
 # Save the Trained Model
 torch.save(model.state_dict(), 'model.pkl')
@@ -61,7 +61,7 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
 # Truncated Backpropagation
 def detach(states):
-    return [Variable(state.data) for state in states]
+    return [state.detach() for state in states]
 
 # Training
 for epoch in range(num_epochs):
@@ -119,4 +119,4 @@ with open(sample_path, 'w') as f:
             print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
 
 # Save the Trained Model
 torch.save(model.state_dict(), 'model.pkl')
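For context on the change above: in truncated backpropagation through time the LSTM hidden and cell states are carried across mini-batches, so they must be cut from the previous computation graph before the next chunk, otherwise backward() would try to propagate through every earlier chunk (and fail once that graph is freed). state.detach() is the post-0.4 way to do this; Variable(state.data) was the 0.3-era equivalent. Below is a minimal, self-contained sketch of how the updated detach() helper is typically used in such a training loop; the toy model, corpus, and hyperparameters are assumptions for illustration, not taken from the diff.

import torch
import torch.nn as nn

# Toy dimensions (assumed for this sketch).
vocab_size, embed_size, hidden_size, num_layers = 1000, 64, 128, 1
seq_length, batch_size, num_epochs = 30, 20, 1

embed = nn.Embedding(vocab_size, embed_size)
lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
head = nn.Linear(hidden_size, vocab_size)
criterion = nn.CrossEntropyLoss()
params = list(embed.parameters()) + list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=0.002)

# Fake corpus laid out as (batch_size, total_length), as in typical LM tutorials.
ids = torch.randint(0, vocab_size, (batch_size, seq_length * 10))

def detach(states):
    # Cut the states from the previous chunk's graph so gradients stop here.
    return [state.detach() for state in states]

for epoch in range(num_epochs):
    # Fresh zero states at the start of each epoch.
    states = [torch.zeros(num_layers, batch_size, hidden_size),
              torch.zeros(num_layers, batch_size, hidden_size)]
    for i in range(0, ids.size(1) - seq_length, seq_length):
        inputs = ids[:, i:i + seq_length]
        targets = ids[:, i + 1:i + 1 + seq_length]

        states = detach(states)                      # truncate BPTT at the chunk boundary
        out, states = lstm(embed(inputs), tuple(states))
        loss = criterion(head(out).reshape(-1, vocab_size), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()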
@@ -77,7 +77,7 @@ for epoch in range(200):
 
         noise = Variable(torch.randn(images.size(0), 128)).cuda()
         fake_images = generator(noise)
-        outputs = discriminator(fake_images)
+        outputs = discriminator(fake_images.detach())
         fake_loss = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -107,4 +107,4 @@ for epoch in range(200):
 
 # Save the Models
 torch.save(generator.state_dict(), './generator.pkl')
 torch.save(discriminator.state_dict(), './discriminator.pkl')
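A side note on the Variable(torch.randn(...)).cuda() line shown unchanged in the hunk above: torch.autograd.Variable is the pre-0.4 API and is now a no-op wrapper, since plain tensors carry autograd state themselves. If this code were updated further, the usual modern idiom would be to create the noise tensor directly on the target device. A small hedged sketch follows; the explicit batch_size stands in for images.size(0), and the device handling is an assumption.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 100  # stands in for images.size(0) in the tutorial's loop
# Equivalent of Variable(torch.randn(batch_size, 128)).cuda() in post-0.4 PyTorch:
noise = torch.randn(batch_size, 128, device=device)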
@@ -77,7 +77,7 @@ for epoch in range(200):
 
         noise = Variable(torch.randn(images.size(0), 128))
         fake_images = generator(noise)
-        outputs = discriminator(fake_images)
+        outputs = discriminator(fake_images.detach())
         fake_loss = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -107,4 +107,4 @@ for epoch in range(200):
 
 # Save the Models
 torch.save(generator.state_dict(), './generator.pkl')
 torch.save(discriminator.state_dict(), './discriminator.pkl')
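The substantive change in both GAN hunks is discriminator(fake_images) -> discriminator(fake_images.detach()). Detaching the fake images before computing the discriminator loss stops gradients from flowing back into the generator during the discriminator update, and it leaves the generator's part of the graph untouched so the generator can then be updated with its own backward pass. Here is a minimal sketch of that two-step update; the Sequential stand-in networks, batch size, and learning rates are assumptions, not the tutorial's actual architectures.

import torch
import torch.nn as nn

# Stand-in networks; the tutorial's architectures are not part of this diff.
generator = nn.Sequential(nn.Linear(128, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)

images = torch.rand(64, 784)                 # fake "real" batch for the sketch
real_labels = torch.ones(images.size(0), 1)
fake_labels = torch.zeros(images.size(0), 1)

# --- Discriminator step ---
noise = torch.randn(images.size(0), 128)
fake_images = generator(noise)
real_loss = criterion(discriminator(images), real_labels)
fake_loss = criterion(discriminator(fake_images.detach()), fake_labels)  # detach: no grads into G
d_loss = real_loss + fake_loss
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

# --- Generator step (no detach here: gradients must reach G) ---
outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()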
@@ -102,7 +102,7 @@ for epoch in range(50):
 
         noise = Variable(torch.randn(images.size(0), 128)).cuda()
         fake_images = generator(noise)
-        outputs = discriminator(fake_images)
+        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -131,4 +131,4 @@ for epoch in range(50):
 
 # Save the Models
 torch.save(generator.state_dict(), './generator.pkl')
 torch.save(discriminator.state_dict(), './discriminator.pkl')
@@ -102,7 +102,7 @@ for epoch in range(50):
 
         noise = Variable(torch.randn(images.size(0), 128))
         fake_images = generator(noise)
-        outputs = discriminator(fake_images)
+        outputs = discriminator(fake_images.detach())
         fake_loss = criterion(outputs, fake_labels)
         fake_score = outputs
 
@@ -131,4 +131,4 @@ for epoch in range(50):
 
 # Save the Models
 torch.save(generator.state_dict(), './generator.pkl')
 torch.save(discriminator.state_dict(), './discriminator.pkl')
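Finally, the torch.save(model.state_dict(), ...) lines at the end of each file persist only the parameter tensors, not the model class, so loading them back requires re-building the same architecture first and then calling load_state_dict. A short sketch of the matching load side, reusing the hypothetical Sequential stand-ins from the GAN sketch above:

import torch
import torch.nn as nn

# Re-create the architectures exactly as they were defined for training
# (Sequential stand-ins; the tutorial's real generator/discriminator are not in this diff).
generator = nn.Sequential(nn.Linear(128, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())

# Parameter names and shapes must match what was saved.
generator.load_state_dict(torch.load('./generator.pkl'))
discriminator.load_state_dict(torch.load('./discriminator.pkl'))
generator.eval()
discriminator.eval()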