Custom classes

4from models.mlp import MLP
5from utils.train import Trainer
6from models.resnet import *

GPU Check

# Select the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU. Everything below that takes `device=` uses this value.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:  " + str(device))

Use different train/test data augmentations

# Evaluation pipeline: no augmentation, only tensor conversion followed by
# per-channel normalisation with mean 0.5 and std 0.5.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training pipeline: random flip, rotation and padded crop, then the same
# tensor conversion and normalisation that the evaluation pipeline applies.
transform_train = transforms.Compose([
    # The flip probability must stay below 1.0: at p=1.0 every image is
    # mirrored, which is a fixed transform rather than an augmentation and
    # makes the training data a mirrored copy of the (unflipped) test data.
    transforms.RandomHorizontalFlip(p=0.5),
    # Rotate by a random angle within +/- 20 degrees.
    transforms.RandomRotation(20),
    # Pad 2 px on each side with a constant fill, then take a random 32x32 crop.
    transforms.RandomCrop(32, (2, 2), pad_if_needed=False, padding_mode='constant'),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Get CIFAR-10 datasets

# Directory the CIFAR-10 files are stored in; with download=True the data is
# fetched into this directory when it is not already present.
save = './data/Cifar10'

# Training split, using the augmenting transform pipeline.
trainset = torchvision.datasets.CIFAR10(
    root=save,
    train=True,
    download=True,
    transform=transform_train,
)

# Test split, using the augmentation-free transform pipeline.
testset = torchvision.datasets.CIFAR10(
    root=save,
    train=False,
    download=True,
    transform=transform_test,
)

Get CIFAR-10 dataloaders

# Training batches: 64 samples each, reshuffled every epoch, loaded by
# 4 worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

# Test batches: same batch size and worker count, but kept in dataset order.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

# Number of epochs handed to the trainer's Train call further down.
epochs = 50

Create the assignment ResNet (part a)

def MyResNet():
    """Build the assignment ResNet (part a) and attach an MLP to it.

    Constructs a ``ResNet`` for ``[32, 32, 3]`` inputs with 10 classes,
    feature channels ``[128, 256, 512]``, batch norm enabled and one stack.
    An ``MLP`` is then created whose input size is derived from
    ``ResNet.GetCurShape()`` and registered on the network via ``AddMLP``.

    The MLP is created with the module-level ``device``.

    Returns:
        The ``ResNet`` instance with the MLP attached.
    """
    backbone = ResNet(
        in_features=[32, 32, 3],
        num_class=10,
        feature_channel_list=[128, 256, 512],
        batch_norm=True,
        num_stacks=1,
    )

    # Create MLP: calculate the input shape as the product of the first three
    # entries of the shape currently reported by the ResNet.
    shape = backbone.GetCurShape()
    flat_features = shape[0] * shape[1] * shape[2]

    head = MLP(
        flat_features,
        10,
        # NOTE(review): presumably the hidden layer sizes; "512, 1024, 512"
        # was left here as an alternative setting - confirm against MLP.
        [],
        [],
        use_batch_norm=False,
        use_dropout=False,
        use_softmax=False,
        device=device,
    )

    backbone.AddMLP(head)
    return backbone
66
# Build the network, move it to the selected device and print a summary for
# an input of shape (3, 32, 32).
model = MyResNet()
model.to(device=device)
summary(model, (3, 32, 32))

# Optimizer
# NOTE(review): the original line carried the trailing remark
# "0.0005 l2_factor.item()" - looks like a leftover alternative setting;
# confirm whether it is still relevant.
opt = torch.optim.Adam(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=1e-8,
)

# Loss function
cost = nn.CrossEntropyLoss()

# Create a trainer
trainer = Trainer(
    model,
    opt,
    cost,
    name="MyResNet",
    device=device,
    use_lr_schedule=True,
)

# Run training
trainer.Train(trainloader, epochs, testloader=testloader)

print('done')