3from utils.train import Trainer # Default custom training class
4from models.resnet import *
5from torchvision import models

GPU Check

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device:  {device}")

Use separate train/test transforms (random augmentation is applied to the training data only)

# Evaluation transform: no augmentation. Convert the image to a float tensor
# in [0, 1], then normalise each RGB channel with mean 0.5 / std 0.5, which
# maps values to [-1, 1].
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

Set the CIFAR-10 data path and the training augmentations

# Root directory where the CIFAR-10 files are stored / downloaded.
save = './data/Cifar10'

# Training transform: random augmentation, then ToTensor and per-channel
# normalisation with mean 0.5 / std 0.5.
transform_train = transforms.Compose([
    # Flip half of the images (torchvision's default probability). With
    # p=1.0 every training image was mirrored, which is a fixed transform
    # rather than augmentation and made train images systematically differ
    # from the unflipped test images.
    transforms.RandomHorizontalFlip(p=0.5),
    # Rotate by a random angle in [-20, 20] degrees.
    transforms.RandomRotation(20),
    # Pad 2 px on left/right and top/bottom, then take a random 32x32 crop,
    # i.e. small random translations.
    transforms.RandomCrop(32, padding=(2, 2), pad_if_needed=False,
                          padding_mode='constant'),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

Get the CIFAR-10 datasets

# CIFAR-10 train/test splits stored under ``save``; download=True fetches the
# files if they are not already present. Only the training split gets the
# random augmentations.
trainset = torchvision.datasets.CIFAR10(root=save, train=True, download=True,
                                        transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=save, train=False, download=True,
                                       transform=transform_test)

Get the CIFAR-10 dataloaders

# Batches of 64 loaded by 4 worker processes. The training set is reshuffled
# every epoch; the test set keeps a fixed order so evaluation is repeatable.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)

testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

Load the pre-trained model

# Pretrained ResNet-18 whose final fully-connected layer is replaced by a
# dropout + linear head producing 10 logits (one per CIFAR-10 class).
# NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
# favour of ``weights=...``; kept for compatibility with older torchvision.
# NOTE(review): the transforms feed 32x32 images normalised with mean/std 0.5,
# not the input size / normalisation the pretrained backbone was trained
# with — consider resizing and matching normalisation if accuracy is low.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features  # width of the features entering the head
model_ft.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, 10),
)

model_ft = model_ft.to(device)

Loss function

# Multi-class classification loss on raw logits (log-softmax + NLL combined).
cost = nn.CrossEntropyLoss()

Optimizer

# Learning rate, shared by the optimizer and the Trainer below.
lr = 0.0005

# Adam with an L2 penalty (weight_decay) of 1e-4 on all model parameters.
# An SGD optimizer was previously constructed here and immediately
# overwritten by this one; that dead assignment has been removed.
opt = torch.optim.Adam(model_ft.parameters(), lr=lr, betas=(0.9, 0.95),
                       weight_decay=1e-4)

Create a trainer

# Wrap model, optimizer and loss in the project's Trainer (utils.train).
# NOTE(review): ``lr`` is passed again, presumably for the learning-rate
# schedule enabled by ``use_lr_schedule=True`` — confirm against Trainer.
trainer = Trainer(model_ft, opt, cost, name="Transfer-learning", lr=lr,
                  use_lr_schedule=True, device=device)

Run training

# Train for ``epochs`` epochs, handing the test loader to the Trainer
# (what it reports for it is defined in utils.train.Trainer).
epochs = 25
trainer.Train(trainloader, epochs, testloader=testloader)

# Check the training error with a single evaluation pass.
# NOTE(review): this previously called ``trainer.Train(trainloader, epochs)``,
# which by its signature runs another ``epochs`` of training rather than
# measuring the error. The pass below only evaluates. It uses
# ``transform_test`` (no random flips/rotations/crops), so the number reflects
# the training images themselves rather than augmented variants.
train_evalset = torchvision.datasets.CIFAR10(root=save, train=True,
                                             download=True,
                                             transform=transform_test)
train_evalloader = torch.utils.data.DataLoader(train_evalset, batch_size=64,
                                               shuffle=False, num_workers=4)

model_ft.eval()  # disable dropout while measuring
n_wrong = 0
n_total = 0
with torch.no_grad():
    for inputs, labels in train_evalloader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model_ft(inputs).argmax(dim=1)
        n_wrong += (preds != labels).sum().item()
        n_total += labels.size(0)
print(f"Train error: {n_wrong / max(n_total, 1):.4f}")

print('done')