mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-06 01:15:59 +08:00
add examples to pytorch basics
This commit is contained in:
@ -27,10 +27,10 @@ b = Variable(torch.Tensor([3]), requires_grad=True)
|
|||||||
# Build a computational graph.
|
# Build a computational graph.
|
||||||
y = w * x + b # y = 2 * x + 3
|
y = w * x + b # y = 2 * x + 3
|
||||||
|
|
||||||
# Compute gradients
|
# Compute gradients.
|
||||||
y.backward()
|
y.backward()
|
||||||
|
|
||||||
# Print out the gradients
|
# Print out the gradients.
|
||||||
print(x.grad) # x.grad = 2
|
print(x.grad) # x.grad = 2
|
||||||
print(w.grad) # w.grad = 1
|
print(w.grad) # w.grad = 1
|
||||||
print(b.grad) # b.grad = 1
|
print(b.grad) # b.grad = 1
|
||||||
@ -146,7 +146,7 @@ for param in resnet.parameters():
|
|||||||
# Replace top layer for finetuning.
|
# Replace top layer for finetuning.
|
||||||
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is for example.
|
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is for example.
|
||||||
|
|
||||||
# For test
|
# For test.
|
||||||
images = Variable(torch.randn(10, 3, 256, 256))
|
images = Variable(torch.randn(10, 3, 256, 256))
|
||||||
outputs = resnet(images)
|
outputs = resnet(images)
|
||||||
print (outputs.size()) # (10, 100)
|
print (outputs.size()) # (10, 100)
|
||||||
|
Reference in New Issue
Block a user