mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-21 14:59:18 +08:00
model serialization code changed
This commit is contained in:
@ -31,7 +31,6 @@ class RNNLM(nn.Module):
|
||||
self.embed = nn.Embedding(vocab_size, embed_size)
|
||||
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
|
||||
self.linear = nn.Linear(hidden_size, vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
@ -120,4 +119,4 @@ with open(sample_path, 'w') as f:
|
||||
print('Sampled [%d/%d] words and save to %s'%(i+1, num_samples, sample_path))
|
||||
|
||||
# Save the Trained Model
|
||||
torch.save(model, 'model.pkl')
|
||||
torch.save(model.state_dict(), 'model.pkl')
|
Reference in New Issue
Block a user