mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-08 02:38:51 +08:00
some examples are edited
This commit is contained in:
@ -79,9 +79,8 @@ print('loss after 1 step optimization: ', loss.data[0])
|
|||||||
|
|
||||||
#======================== Loading data from numpy ========================#
|
#======================== Loading data from numpy ========================#
|
||||||
a = np.array([[1,2], [3,4]])
|
a = np.array([[1,2], [3,4]])
|
||||||
b = torch.from_numpy(a)
|
b = torch.from_numpy(a) # convert numpy array to torch tensor
|
||||||
print (b)
|
c = b.numpy() # convert torch tensor to numpy array
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#===================== Implementing the input pipline =====================#
|
#===================== Implementing the input pipline =====================#
|
||||||
@ -113,6 +112,7 @@ for images, labels in train_loader:
|
|||||||
# Your training code will be written here
|
# Your training code will be written here
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
#===================== Input pipline for custom dataset =====================#
|
#===================== Input pipline for custom dataset =====================#
|
||||||
# You should build custom dataset as below.
|
# You should build custom dataset as below.
|
||||||
class CustomDataset(data.Dataset):
|
class CustomDataset(data.Dataset):
|
||||||
@ -123,14 +123,16 @@ class CustomDataset(data.Dataset):
|
|||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# TODO
|
# TODO
|
||||||
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
|
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
|
||||||
# 2. Return a data pair (e.g. image and label).
|
# 2. Preprocess the data (e.g. torchvision.Transform).
|
||||||
|
# 3. Return a data pair (e.g. image and label).
|
||||||
pass
|
pass
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
# You should change 0 to the total size of your dataset.
|
# You should change 0 to the total size of your dataset.
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Then, you can just use prebuilt torch's data loader.
|
# Then, you can just use prebuilt torch's data loader.
|
||||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
custom_dataset = CustomDataset()
|
||||||
|
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
|
||||||
batch_size=100,
|
batch_size=100,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=2)
|
num_workers=2)
|
||||||
@ -153,6 +155,11 @@ outputs = resnet(images)
|
|||||||
print (outputs.size()) # (10, 100)
|
print (outputs.size()) # (10, 100)
|
||||||
|
|
||||||
|
|
||||||
#============================ Save and load model ============================#
|
#============================ Save and load the model ============================#
|
||||||
|
# Save and load the entire model.
|
||||||
torch.save(resnet, 'model.pkl')
|
torch.save(resnet, 'model.pkl')
|
||||||
model = torch.load('model.pkl')
|
model = torch.load('model.pkl')
|
||||||
|
|
||||||
|
# Save and load only the model parameters(recommended).
|
||||||
|
torch.save(resnet.state_dict(), 'params.pkl')
|
||||||
|
resnet.load_state_dict(torch.load('params.pkl'))
|
Reference in New Issue
Block a user