mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-08 02:38:51 +08:00
some examples are edited
This commit is contained in:
@ -79,9 +79,8 @@ print('loss after 1 step optimization: ', loss.data[0])
|
||||
|
||||
#======================== Loading data from numpy ========================#
|
||||
a = np.array([[1,2], [3,4]])
|
||||
b = torch.from_numpy(a)
|
||||
print (b)
|
||||
|
||||
b = torch.from_numpy(a) # convert numpy array to torch tensor
|
||||
c = b.numpy() # convert torch tensor to numpy array
|
||||
|
||||
|
||||
#===================== Implementing the input pipline =====================#
|
||||
@ -113,6 +112,7 @@ for images, labels in train_loader:
|
||||
# Your training code will be written here
|
||||
pass
|
||||
|
||||
|
||||
#===================== Input pipline for custom dataset =====================#
|
||||
# You should build custom dataset as below.
|
||||
class CustomDataset(data.Dataset):
|
||||
@ -123,14 +123,16 @@ class CustomDataset(data.Dataset):
|
||||
def __getitem__(self, index):
|
||||
# TODO
|
||||
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
|
||||
# 2. Return a data pair (e.g. image and label).
|
||||
# 2. Preprocess the data (e.g. torchvision.Transform).
|
||||
# 3. Return a data pair (e.g. image and label).
|
||||
pass
|
||||
def __len__(self):
|
||||
# You should change 0 to the total size of your dataset.
|
||||
return 0
|
||||
|
||||
# Then, you can just use prebuilt torch's data loader.
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||
custom_dataset = CustomDataset()
|
||||
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
|
||||
batch_size=100,
|
||||
shuffle=True,
|
||||
num_workers=2)
|
||||
@ -153,6 +155,11 @@ outputs = resnet(images)
|
||||
print (outputs.size()) # (10, 100)
|
||||
|
||||
|
||||
#============================ Save and load model ============================#
|
||||
#============================ Save and load the model ============================#
|
||||
# Save and load the entire model.
|
||||
torch.save(resnet, 'model.pkl')
|
||||
model = torch.load('model.pkl')
|
||||
|
||||
# Save and load only the model parameters(recommended).
|
||||
torch.save(resnet.state_dict(), 'params.pkl')
|
||||
resnet.load_state_dict(torch.load('params.pkl'))
|
Reference in New Issue
Block a user