diff --git a/tutorials/00 - PyTorch Basics/main.py b/tutorials/00 - PyTorch Basics/main.py index 1bd907b..4c99ea0 100644 --- a/tutorials/00 - PyTorch Basics/main.py +++ b/tutorials/00 - PyTorch Basics/main.py @@ -79,9 +79,8 @@ print('loss after 1 step optimization: ', loss.data[0]) #======================== Loading data from numpy ========================# a = np.array([[1,2], [3,4]]) -b = torch.from_numpy(a) -print (b) - +b = torch.from_numpy(a) # convert numpy array to torch tensor +c = b.numpy() # convert torch tensor to numpy array #===================== Implementing the input pipline =====================# @@ -113,6 +112,7 @@ for images, labels in train_loader: # Your training code will be written here pass + #===================== Input pipline for custom dataset =====================# # You should build custom dataset as below. class CustomDataset(data.Dataset): @@ -123,14 +123,16 @@ class CustomDataset(data.Dataset): def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). - # 2. Return a data pair (e.g. image and label). + # 2. Preprocess the data (e.g. torchvision.Transform). + # 3. Return a data pair (e.g. image and label). pass def __len__(self): # You should change 0 to the total size of your dataset. return 0 # Then, you can just use prebuilt torch's data loader. -train_loader = torch.utils.data.DataLoader(dataset=train_dataset, +custom_dataset = CustomDataset() +train_loader = torch.utils.data.DataLoader(dataset=custom_dataset, batch_size=100, shuffle=True, num_workers=2) @@ -153,6 +155,11 @@ outputs = resnet(images) print (outputs.size()) # (10, 100) -#============================ Save and load model ============================# +#============================ Save and load the model ============================# +# Save and load the entire model. torch.save(resnet, 'model.pkl') -model = torch.load('model.pkl') \ No newline at end of file +model = torch.load('model.pkl') + +# Save and load only the model parameters(recommended). +torch.save(resnet.state_dict(), 'params.pkl') +resnet.load_state_dict(torch.load('params.pkl')) \ No newline at end of file