Files
2021-02-25 20:32:44 +05:30

84 lines
2.5 KiB
Python
Executable File

#!/bin/python
# Custom classes
from models.mlp import MLP
from utils.train import Trainer
from models.resnet import *
# GPU Check
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))
#Use different train/test data augmentations
transform_test = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(p=1.0),
transforms.RandomRotation(20),
transforms.RandomCrop(32, (2, 2), pad_if_needed=False, padding_mode='constant'),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Get Cifar 10 Datasets
save='./data/Cifar10'
trainset = torchvision.datasets.CIFAR10(root=save, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=save, train=False, download=True, transform=transform_test)
# Get Cifar 10 Dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False, num_workers=4)
epochs = 50
#################################
# Create the assignment Resnet (part a)
#################################
def MyResNet():
resnet = ResNet(in_features= [32, 32, 3],
num_class=10,
feature_channel_list = [128, 256, 512],
batch_norm= True,
num_stacks=1
)
# Create MLP
# Calculate the input shape
s = resnet.GetCurShape()
in_features = s[0]*s[1]*s[2]
mlp = MLP(in_features,
10,
[], #512, 1024, 512
[],
use_batch_norm=False,
use_dropout=False,
use_softmax=False,
device=device)
resnet.AddMLP(mlp)
return resnet
model = MyResNet()
model.to(device=device)
summary(model, (3, 32,32))
# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.95), weight_decay=1e-8) #0.0005 l2_factor.item()
# Loss function
cost = nn.CrossEntropyLoss()
# Create a trainer
trainer = Trainer(model, opt, cost, name="MyResNet", device=device, use_lr_schedule =True)
# Run training
trainer.Train(trainloader, epochs, testloader=testloader)
print('done')