diff --git a/tutorials/01-basics/pytorch_basics/main.py b/tutorials/01-basics/pytorch_basics/main.py index f958d7c..78b692a 100644 --- a/tutorials/01-basics/pytorch_basics/main.py +++ b/tutorials/01-basics/pytorch_basics/main.py @@ -23,9 +23,9 @@ import torchvision.transforms as transforms # ================================================================== # # Create tensors. -x = torch.tensor(1, requires_grad=True) -w = torch.tensor(2, requires_grad=True) -b = torch.tensor(3, requires_grad=True) +x = torch.tensor(1., requires_grad=True) +w = torch.tensor(2., requires_grad=True) +b = torch.tensor(3., requires_grad=True) # Build a computational graph. y = w * x + b # y = 2 * x + 3