diff --git a/tutorials/01-basics/pytorch_basics/main.py b/tutorials/01-basics/pytorch_basics/main.py index 17a5070..153d23a 100644 --- a/tutorials/01-basics/pytorch_basics/main.py +++ b/tutorials/01-basics/pytorch_basics/main.py @@ -150,7 +150,7 @@ for param in resnet.parameters(): resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is for example. # For test. -images = Variable(torch.randn(10, 3, 256, 256)) +images = Variable(torch.randn(10, 3, 224, 224)) outputs = resnet(images) print (outputs.size()) # (10, 100) @@ -162,4 +162,4 @@ model = torch.load('model.pkl') # Save and load only the model parameters(recommended). torch.save(resnet.state_dict(), 'params.pkl') -resnet.load_state_dict(torch.load('params.pkl')) \ No newline at end of file +resnet.load_state_dict(torch.load('params.pkl'))