model serialization code changed

This commit is contained in:
yunjey
2017-03-18 16:28:19 +09:00
parent 6c74dfab60
commit 3a7a9cf07b
16 changed files with 41 additions and 37 deletions

View File

@ -31,7 +31,6 @@ class RNNLM(nn.Module):
self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
self.linear = nn.Linear(hidden_size, vocab_size)
self.init_weights()
def init_weights(self):
@ -120,4 +119,4 @@ with open(sample_path, 'w') as f:
print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
# Save the Trained Model
torch.save(model, 'model.pkl')
torch.save(model.state_dict(), 'model.pkl')