mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-07 01:54:41 +08:00
Fixed the input image size for the resnet18
This commit is contained in:
@ -150,7 +150,7 @@ for param in resnet.parameters():
|
|||||||
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is for example.
|
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is for example.
|
||||||
|
|
||||||
# For test.
|
# For test.
|
||||||
images = Variable(torch.randn(10, 3, 256, 256))
|
images = Variable(torch.randn(10, 3, 224, 224))
|
||||||
outputs = resnet(images)
|
outputs = resnet(images)
|
||||||
print (outputs.size()) # (10, 100)
|
print (outputs.size()) # (10, 100)
|
||||||
|
|
||||||
@ -162,4 +162,4 @@ model = torch.load('model.pkl')
|
|||||||
|
|
||||||
# Save and load only the model parameters(recommended).
|
# Save and load only the model parameters(recommended).
|
||||||
torch.save(resnet.state_dict(), 'params.pkl')
|
torch.save(resnet.state_dict(), 'params.pkl')
|
||||||
resnet.load_state_dict(torch.load('params.pkl'))
|
resnet.load_state_dict(torch.load('params.pkl'))
|
||||||
|
Reference in New Issue
Block a user