mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-24 18:45:33 +08:00
tutorials are added
This commit is contained in:
397
tutorials/00 - PyTorch Basics/basics.ipynb
Normal file
397
tutorials/00 - PyTorch Basics/basics.ipynb
Normal file
@ -0,0 +1,397 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch \n",
|
||||
"import torchvision\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.utils.data as data\n",
|
||||
"import numpy as np\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"import torchvision.datasets as dsets\n",
|
||||
"from torch.autograd import Variable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simple Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"-1.2532 -1.1120 0.9717\n",
|
||||
"-2.3617 0.1516 1.1280\n",
|
||||
"-2.1599 0.0828 -1.4305\n",
|
||||
" 0.5265 0.5020 -2.1852\n",
|
||||
"-0.9197 0.1772 -1.1378\n",
|
||||
"[torch.FloatTensor of size 5x3]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# random normal\n",
|
||||
"x = torch.randn(5, 3)\n",
|
||||
"print (x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# build a layer\n",
|
||||
"linear = nn.Linear(3, 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Parameter containing:\n",
|
||||
" 0.3884 -0.3335 -0.5146\n",
|
||||
"-0.3692 0.1977 -0.4081\n",
|
||||
"[torch.FloatTensor of size 2x3]\n",
|
||||
"\n",
|
||||
"Parameter containing:\n",
|
||||
"-0.4826\n",
|
||||
"-0.0038\n",
|
||||
"[torch.FloatTensor of size 2]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Sess weight and bias\n",
|
||||
"print (linear.weight)\n",
|
||||
"print (linear.bias)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Variable containing:\n",
|
||||
"-1.0986 -0.1575\n",
|
||||
"-2.0311 0.4378\n",
|
||||
"-0.6131 1.3938\n",
|
||||
" 0.6790 0.7929\n",
|
||||
"-0.3134 0.8351\n",
|
||||
"[torch.FloatTensor of size 5x2]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# forward propagate\n",
|
||||
"y = linear(Variable(x))\n",
|
||||
"print (y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Convert numpy array to torch tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# convert numpy array to tensor\n",
|
||||
"a = np.array([[1,2], [3,4]])\n",
|
||||
"b = torch.from_numpy(a)\n",
|
||||
"print (b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Input pipeline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### (1) Preprocessing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Image Preprocessing \n",
|
||||
"transform = transforms.Compose([\n",
|
||||
" transforms.Scale(40),\n",
|
||||
" transforms.RandomHorizontalFlip(),\n",
|
||||
" transforms.RandomCrop(32),\n",
|
||||
" transforms.ToTensor()])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (2) Define Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Files already downloaded and verified\n",
|
||||
"torch.Size([3, 32, 32])\n",
|
||||
"6\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# download and loading dataset f\n",
|
||||
"train_dataset = dsets.CIFAR10(root='./data/',\n",
|
||||
" train=True, \n",
|
||||
" transform=transform,\n",
|
||||
" download=True)\n",
|
||||
"\n",
|
||||
"image, label = train_dataset[0]\n",
|
||||
"print (image.size())\n",
|
||||
"print (label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (3) Data Loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# data loader provides queue and thread in a very simple way\n",
|
||||
"train_loader = data.DataLoader(dataset=train_dataset,\n",
|
||||
" batch_size=100, \n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# iteration start then queue and thread start\n",
|
||||
"data_iter = iter(train_loader)\n",
|
||||
"\n",
|
||||
"# mini-batch images and labels\n",
|
||||
"images, labels = data_iter.next()\n",
|
||||
"\n",
|
||||
"for images, labels in train_loader:\n",
|
||||
" # your training code will be written here\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (4) What about custom dataset not cifar10?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class CustomDataset(data.Dataset):\n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" # You should build this function to return one data for given index\n",
|
||||
" pass\n",
|
||||
" def __len__(self):\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "TypeError",
|
||||
"evalue": "'NoneType' object cannot be interpreted as an integer",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-26-a76c7b5c92c3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m num_workers=2)\n\u001b[0m",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, dataset, batch_size, shuffle, sampler, num_workers, collate_fn, pin_memory)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 252\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mRandomSampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 253\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSequentialSampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m/home/yunjey/anaconda3/lib/python3.5/site-packages/torch/utils/data/sampler.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data_source)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata_source\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_source\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mTypeError\u001b[0m: 'NoneType' object cannot be interpreted as an integer"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"custom_dataset = CustomDataset()\n",
|
||||
"data.DataLoader(dataset=custom_dataset,\n",
|
||||
" batch_size=100, \n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using Pretrained Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Downloading: \"https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth\" to /home/yunjey/.torch/models/resnet18-5c106cde.pth\n",
|
||||
"100%|██████████| 46827520/46827520 [07:48<00:00, 99907.53it/s] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Download and load pretrained model\n",
|
||||
"resnet = torchvision.models.resnet18(pretrained=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# delete top layer for finetuning\n",
|
||||
"sub_model = nn.Sequentialtial(*list(resnet.children()[:-1]))\n",
|
||||
"\n",
|
||||
"# for test\n",
|
||||
"images = Variable(torch.randn(10, 3, 256, 256))\n",
|
||||
"print (resnet(images).size())\n",
|
||||
"print (sub_model(images).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Save and Load Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save and load the trained model\n",
|
||||
"torch.save(sub_model, 'model.pkl')\n",
|
||||
"\n",
|
||||
"model = torch.load('model.pkl')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"anaconda-cloud": {},
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda root]",
|
||||
"language": "python",
|
||||
"name": "conda-root-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
Reference in New Issue
Block a user