mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-24 10:08:24 +08:00
code for saving the model is added
This commit is contained in:
@ -58,4 +58,7 @@ predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
|
||||
plt.plot(x_train, y_train, 'ro', label='Original data')
|
||||
plt.plot(x_train, predicted, label='Fitted line')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
# Save the Model
|
||||
torch.save(model, 'model.pkl')
|
Reference in New Issue
Block a user