mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
Implemented CNN filter visualization and Hyperparameter tuning (Ray)
This commit is contained in:

committed by
Varuna Jayasiri

parent
ba5c7200e8
commit
73ce42ec4c
164
labml_nn/cnn/cnn_visualization.py
Executable file
164
labml_nn/cnn/cnn_visualization.py
Executable file
@ -0,0 +1,164 @@
|
||||
#!/bin/python
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torchsummary import summary
|
||||
from functools import partial
|
||||
from skimage.filters import sobel, sobel_h, roberts
|
||||
from models.cnn import CNN
|
||||
from utils.dataloader import *
|
||||
from utils.train import Trainer
|
||||
|
||||
# Check if GPU is available
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print("Device: " + str(device))
|
||||
|
||||
# Cifar 10 Datasets location
|
||||
save='./data/Cifar10'
|
||||
|
||||
# Transformations train
|
||||
transform_train = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# Load train dataset and dataloader
|
||||
trainset = LoadCifar10DatasetTrain(save, transform_train)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
|
||||
shuffle=True, num_workers=4)
|
||||
|
||||
# Transformations test
|
||||
transform_test = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# Load test dataset and dataloader
|
||||
testset = LoadCifar10DatasetTest(save, transform_test)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
|
||||
shuffle=False, num_workers=4)
|
||||
|
||||
# Create CNN model
|
||||
def GetCNN():
|
||||
cnn = CNN( in_features=(32,32,3),
|
||||
out_features=10,
|
||||
conv_filters=[32,32,64,64],
|
||||
conv_kernel_size=[3,3,3,3],
|
||||
conv_strides=[1,1,1,1],
|
||||
conv_pad=[0,0,0,0],
|
||||
max_pool_kernels=[None, (2,2), None, (2,2)],
|
||||
max_pool_strides=[None,2,None,2],
|
||||
use_dropout=False,
|
||||
use_batch_norm=True, #False
|
||||
actv_func=["relu", "relu", "relu", "relu"],
|
||||
device=device
|
||||
)
|
||||
|
||||
return cnn
|
||||
|
||||
model = GetCNN()
|
||||
|
||||
# Display model specifications
|
||||
summary(model, (3,32,32))
|
||||
|
||||
# Send model to GPU
|
||||
model.to(device)
|
||||
|
||||
# Specify optimizer
|
||||
opt = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.95))
|
||||
|
||||
# Specify loss function
|
||||
cost = nn.CrossEntropyLoss()
|
||||
|
||||
# Train the model
|
||||
trainer = Trainer(device=device, name="Basic_CNN")
|
||||
epochs = 5
|
||||
trainer.Train(model, trainloader, testloader, cost=cost, opt=opt, epochs=epochs)
|
||||
|
||||
# Load best saved model for inference
|
||||
model_loaded = GetCNN()
|
||||
|
||||
# Specify location of saved model
|
||||
PATH = "./save/Basic_CNN-best-model/model.pt"
|
||||
checkpoint = torch.load(PATH)
|
||||
|
||||
# load the saved model
|
||||
model_loaded.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
# intialization for hooks and storing activation of ReLU layers
|
||||
activation = {}
|
||||
hooks = []
|
||||
|
||||
# Hook function saves activation of a particular layer
|
||||
def hook_fn(model, input, output, name):
|
||||
activation[name] = output.cpu().detach().numpy()
|
||||
|
||||
# Registering hooks
|
||||
count =0
|
||||
conv_count = 0
|
||||
for name, layer in model_loaded.named_modules():
|
||||
if isinstance(layer, nn.ReLU):
|
||||
count +=1
|
||||
hook = layer.register_forward_hook(partial(hook_fn, name=f"{layer._get_name()}-{count}")) #f"{type(layer).__name__}-{name}"
|
||||
hooks.append(hook)
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
conv_count += 1
|
||||
|
||||
# Displaying image used for inference
|
||||
data, _ = trainset[15]
|
||||
imshow(data)
|
||||
|
||||
# Infering model to save activation of ReLU layers
|
||||
output = model_loaded(data[None].to(device))
|
||||
|
||||
# Removing hooks
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Function to display output of a particular ReLU layer
|
||||
def output_one_layer(layer_num):
|
||||
assert 1 <= layer_num <= len(activation), "Wrong layer number"
|
||||
|
||||
layer_name = f"ReLu-{layer_num}"
|
||||
act = activation[f"ReLU-{layer_num}"]
|
||||
if act.shape[1]==32:
|
||||
rows = 4
|
||||
columns = 8
|
||||
elif act.shape[1]==64:
|
||||
rows = 8
|
||||
columns = 8
|
||||
|
||||
fig = plt.figure(figsize=(rows, columns))
|
||||
for idx in range(1, columns * rows + 1):
|
||||
fig.add_subplot(rows, columns, idx)
|
||||
plt.imshow(sobel(act[0][idx-1]), cmap=plt.cm.gray)
|
||||
|
||||
# try different filters
|
||||
# plt.imshow(act[0][idx-1], cmap='viridis', vmin=0, vmax=act.max())
|
||||
# plt.imshow(act[0][idx - 1], cmap='hot')
|
||||
# plt.imshow(roberts(act[0][idx - 1]), cmap=plt.cm.gray)
|
||||
# plt.imshow(sobel_h(act[0][idx-1]), cmap=plt.cm.gray)
|
||||
|
||||
plt.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# Function to display output of all ReLU layer after Convulution layers
|
||||
def output_all_layers():
|
||||
for [name, output], count in zip(activation.items(), range(conv_count)):
|
||||
if output.shape[1] == 32:
|
||||
_, axs = plt.subplots(8, 4, figsize=(8, 4))
|
||||
elif output.shape[1] == 64:
|
||||
_, axs = plt.subplots(8, 8, figsize=(8, 8))
|
||||
|
||||
for ax, out in zip(np.ravel(axs), output[0]):
|
||||
ax.imshow(sobel(out), cmap=plt.cm.gray)
|
||||
ax.axis('off')
|
||||
|
||||
plt.suptitle(name)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# Choose either one to display
|
||||
output_one_layer(layer_num=3) # choose layer number
|
||||
output_all_layers()
|
||||
|
193
labml_nn/cnn/models/cnn.py
Executable file
193
labml_nn/cnn/models/cnn.py
Executable file
@ -0,0 +1,193 @@
|
||||
#!/bin/python
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
# Use the formula:
|
||||
# [(W-K+2P)/S] + 1
|
||||
# where:
|
||||
# W: Is the input volume size for each dimension
|
||||
# K: Is the kernel size
|
||||
# P: Is the padding
|
||||
# S: Is the stride
|
||||
|
||||
def CalcConvFormula(W, K, P, S):
|
||||
return int(np.floor(((W - K + 2 * P) / S) + 1))
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/53580088/calculate-the-output-size-in-convolution-layer
|
||||
# Calculate the output shape after applying a convolution
|
||||
def CalcConvOutShape(in_shape, kernel_size, padding, stride, out_filters):
|
||||
# Multiple options for different kernel shapes
|
||||
if type(kernel_size) == int:
|
||||
out_shape = [CalcConvFormula(in_shape[i], kernel_size, padding, stride) for i in range(2)]
|
||||
else:
|
||||
out_shape = [CalcConvFormula(in_shape[i], kernel_size[i], padding, stride) for i in range(2)]
|
||||
|
||||
return (out_shape[0], out_shape[1], out_filters) # , batch_size... but not necessary.
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self
|
||||
, in_features
|
||||
, out_features
|
||||
, conv_filters
|
||||
, conv_kernel_size
|
||||
, conv_strides
|
||||
, conv_pad
|
||||
, actv_func
|
||||
, max_pool_kernels
|
||||
, max_pool_strides
|
||||
, l1=120
|
||||
, l2=84
|
||||
, MLP=None
|
||||
, pre_module_list=None
|
||||
, use_dropout=False
|
||||
, use_batch_norm=False
|
||||
, device="cpu"
|
||||
):
|
||||
super(CNN, self).__init__()
|
||||
|
||||
# Gerneral model Properties
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
# Convolution operations
|
||||
self.conv_filters = conv_filters
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_strides = conv_strides
|
||||
self.conv_pad = conv_pad
|
||||
|
||||
# Convolution Activiations
|
||||
self.actv_func = actv_func
|
||||
|
||||
# Max Pools
|
||||
self.max_pool_kernels = max_pool_kernels
|
||||
self.max_pool_strides = max_pool_strides
|
||||
|
||||
# Regularization
|
||||
self.use_dropout = use_dropout
|
||||
self.use_batch_norm = use_batch_norm
|
||||
|
||||
# Tunable parameters
|
||||
self.l1 = l1
|
||||
self.l2 = l2
|
||||
|
||||
# Number of conv/pool/act/batch_norm/dropout layers we add
|
||||
self.n_conv_layers = len(self.conv_filters)
|
||||
|
||||
# Create the module list
|
||||
if pre_module_list:
|
||||
self.module_list = pre_module_list
|
||||
else:
|
||||
self.module_list = nn.ModuleList()
|
||||
|
||||
self.shape_list = []
|
||||
self.shape_list.append(self.in_features)
|
||||
|
||||
self.build_()
|
||||
|
||||
# Send to gpu
|
||||
self.device = device
|
||||
self.to(self.device)
|
||||
|
||||
def build_(self):
|
||||
# Track shape
|
||||
cur_shape = self.GetCurShape()
|
||||
|
||||
for i in range(self.n_conv_layers):
|
||||
if i == 0:
|
||||
if len(self.in_features) == 2:
|
||||
in_channels = 1
|
||||
else:
|
||||
in_channels = self.in_features[2]
|
||||
else:
|
||||
in_channels = self.conv_filters[i - 1]
|
||||
|
||||
cur_shape = CalcConvOutShape(cur_shape, self.conv_kernel_size[i], self.conv_pad[i], self.conv_strides[i],
|
||||
self.conv_filters[i])
|
||||
self.shape_list.append(cur_shape)
|
||||
|
||||
conv = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=self.conv_filters[i],
|
||||
kernel_size=self.conv_kernel_size[i],
|
||||
padding=self.conv_pad[i],
|
||||
stride=self.conv_strides[i]
|
||||
)
|
||||
self.module_list.append(conv)
|
||||
|
||||
if self.use_batch_norm:
|
||||
self.module_list.append(nn.BatchNorm2d(cur_shape[2]))
|
||||
|
||||
if self.use_dropout:
|
||||
self.module_list.append(nn.Dropout(p=0.15))
|
||||
|
||||
# Add the Activation function
|
||||
if self.actv_func[i]:
|
||||
self.module_list.append(GetActivation(name=self.actv_func[i]))
|
||||
|
||||
if self.max_pool_kernels:
|
||||
if self.max_pool_kernels[i]:
|
||||
self.module_list.append(nn.MaxPool2d(self.max_pool_kernels[i], stride=self.max_pool_strides[i]))
|
||||
cur_shape = CalcConvOutShape(cur_shape, self.max_pool_kernels[i], 0, self.max_pool_strides[i],
|
||||
cur_shape[2])
|
||||
self.shape_list.append(cur_shape)
|
||||
|
||||
# # Adding MLP
|
||||
s = self.GetCurShape()
|
||||
in_features = s[0] * s[1] * s[2]
|
||||
self.module_list.append(nn.Linear(in_features, self.l1))
|
||||
self.module_list.append(nn.ReLU())
|
||||
self.module_list.append(nn.Linear(self.l1, self.l2))
|
||||
self.module_list.append(nn.ReLU())
|
||||
self.module_list.append(nn.Linear(self.l2, self.out_features))
|
||||
|
||||
def forward(self, x):
|
||||
j = 0
|
||||
for i, module in enumerate(self.module_list):
|
||||
if isinstance(module, nn.Linear) and j == 0:
|
||||
x = torch.flatten(x.float(), start_dim=1)
|
||||
j = 1
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
def GetCurShape(self):
|
||||
return self.shape_list[-1]
|
||||
|
||||
def GetCNN(l1=120, l2=84):
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
cnn = CNN(in_features=(32, 32, 3),
|
||||
out_features=10,
|
||||
conv_filters=[32, 32, 64, 64], # , 128, 256, 512
|
||||
conv_kernel_size=[3, 3, 3, 3], # ,3,3,1
|
||||
conv_strides=[1, 1, 1, 1], # ,1,1,1
|
||||
conv_pad=[0, 0, 0, 0, 0, 0, 0],
|
||||
actv_func=["relu", "relu", "relu", "relu"], # , "relu", "relu", "relu"
|
||||
max_pool_kernels=[None, (2, 2), None, (2, 2)], # , None, None, None
|
||||
max_pool_strides=[None, 2, None, 2], # , None,None, None
|
||||
l1=l1,
|
||||
l2=l2,
|
||||
use_dropout=False,
|
||||
use_batch_norm=True, # False
|
||||
device=device
|
||||
)
|
||||
|
||||
return cnn
|
||||
|
||||
|
||||
def GetActivation(name="relu"):
|
||||
if name == "relu":
|
||||
return nn.ReLU()
|
||||
elif name == "leakyrelu":
|
||||
return nn.LeakyReLU()
|
||||
elif name == "Sigmoid":
|
||||
return nn.Sigmoid()
|
||||
elif name == "Tanh":
|
||||
return nn.Tanh()
|
||||
elif name == "Identity":
|
||||
return nn.Identity()
|
||||
else:
|
||||
return nn.ReLU()
|
106
labml_nn/cnn/ray_tune.py
Normal file
106
labml_nn/cnn/ray_tune.py
Normal file
@ -0,0 +1,106 @@
|
||||
#!/bin/python
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
|
||||
from utils.train import Trainer
|
||||
from models.cnn import GetCNN
|
||||
|
||||
# Check if GPU is available
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print("Device: " + str(device))
|
||||
|
||||
#
|
||||
num_samples= 40 # for multiple trials
|
||||
max_num_epochs= 25
|
||||
gpus_per_trial= 1
|
||||
|
||||
# Cifar 10 Datasets location
|
||||
data_dir = './data/Cifar10'
|
||||
|
||||
"""config - returns a dict of hyperparameters
|
||||
|
||||
Selecting different hyperparameters for tuning
|
||||
l1 : Number of units in first fully connected layer
|
||||
l2 : Number of units in second fully connected layer
|
||||
lr : Learning rate
|
||||
decay : Decay rate for regularization
|
||||
batch_size : Batch size of test and train data
|
||||
"""
|
||||
config = {
|
||||
"l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), # eg. 4, 8, 16 .. 512
|
||||
"l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), # eg. 4, 8, 16 .. 512
|
||||
"lr": tune.loguniform(1e-4, 1e-1), # Sampling from log uniform distribution
|
||||
"decay": tune.sample_from(lambda _: 10 ** np.random.randint(-7, -3)), # eg. 1e-7, 1e-6, .. 1e-3
|
||||
"batch_size": tune.choice([32, 64, 128, 256])
|
||||
}
|
||||
|
||||
# calling trainer
|
||||
trainer = Trainer(device=device)
|
||||
|
||||
"""ASHA (Asynchronous Successive Halving Algorithm) scheduler
|
||||
max_t : Maximum number of units per trail (can be time or epochs)
|
||||
grace_period : Stop trials after specific number of unit if model is not performing well (can be time or epochs)
|
||||
reduction_factor : Set halving rate
|
||||
"""
|
||||
scheduler = ASHAScheduler(
|
||||
max_t=max_num_epochs,
|
||||
grace_period=4,
|
||||
reduction_factor=4)
|
||||
|
||||
|
||||
|
||||
"""Population based training scheduler
|
||||
time_attr : Can be time or epochs
|
||||
metric : Objective of training (loss or accuracy)
|
||||
perturbation_interval : Perturbation occur after specified unit (can be time or epochs)
|
||||
hyperparam_mutations : Hyperparameters to mutate
|
||||
"""
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr= "training_iteration", # epochs
|
||||
metric='loss', # loss is objective function
|
||||
mode='min', # minimizing loss is objective of training
|
||||
perturbation_interval=5.0, # after 5 epochs perturbate
|
||||
hyperparam_mutations={
|
||||
"lr": [1e-3, 5e-4, 1e-4, 5e-4, 1e-5], # choose from given learning rates
|
||||
"batch_size": [64, 128, 256], # choose from given batch sizes
|
||||
"decay": tune.uniform(10**-8, 10**-4) # sample from uniform distribution
|
||||
}
|
||||
)
|
||||
|
||||
result = tune.run(
|
||||
tune.with_parameters(trainer.Train_ray, data_dir=data_dir),
|
||||
name="ray_test_basic-CNN", # name for identifying models (checkpoints)
|
||||
scheduler=scheduler, # select scheduler PBT or ASHA
|
||||
resources_per_trial={"cpu": 8, "gpu": gpus_per_trial}, # select number of CPUs or GPUs
|
||||
config=config, # input config dict consisting of different hyperparameters
|
||||
stop={
|
||||
"training_iteration": max_num_epochs, # stopping criterea
|
||||
},
|
||||
metric="loss", # uncomment for ASHA scheduler
|
||||
mode="min", # uncomment for ASHA scheduler
|
||||
num_samples=num_samples,
|
||||
verbose=True, # keep to true to check how training progresses
|
||||
fail_fast=True, # fail on first error
|
||||
keep_checkpoints_num=5, # number of checkpoints to be saved per num_samples
|
||||
|
||||
)
|
||||
|
||||
best_trial = result.get_best_trial("loss", "min", "last")
|
||||
print("Best configuration: {}".format(best_trial.config))
|
||||
print("Best validation loss: {}".format(best_trial.last_result["loss"]))
|
||||
print("Best validation accuracy: {}".format(
|
||||
best_trial.last_result["accuracy"]))
|
||||
|
||||
|
||||
best_trained_model = GetCNN(best_trial.config["l1"], best_trial.config["l2"])
|
||||
best_trained_model.to(device)
|
||||
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
|
||||
model_state, optimizer_state = torch.load(checkpoint_path)
|
||||
best_trained_model.load_state_dict(model_state)
|
||||
|
||||
# Check accuracy of best model
|
||||
test_acc = trainer.Test(best_trained_model, save=data_dir)
|
||||
print("Best Test accuracy: {}".format(test_acc))
|
BIN
labml_nn/cnn/save/Basic_CNN-best-model/model.pt
Normal file
BIN
labml_nn/cnn/save/Basic_CNN-best-model/model.pt
Normal file
Binary file not shown.
83
labml_nn/cnn/utils/dataloader.py
Normal file
83
labml_nn/cnn/utils/dataloader.py
Normal file
@ -0,0 +1,83 @@
|
||||
#!/bin/python
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset, random_split
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def LoadCifar10DatasetTrain(save, transform=None):
|
||||
trainset = torchvision.datasets.CIFAR10(root=save, train=True,
|
||||
download=True, transform=transform)
|
||||
return trainset
|
||||
|
||||
def LoadCifar10DatasetTest(save, transform):
|
||||
return torchvision.datasets.CIFAR10(root=save, train=False,
|
||||
download=False, transform=transform)
|
||||
|
||||
def GetCustTransform():
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomRotation(20),
|
||||
transforms.RandomCrop(32, (2, 2), pad_if_needed=False, padding_mode='constant'),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
return transform_train
|
||||
|
||||
def Dataloader_train_valid(save, batch_size):
|
||||
|
||||
# See utils/dataloader.py for data augmentations
|
||||
transform_train_valid = GetCustTransform()
|
||||
|
||||
# Get Cifar 10 Datasets
|
||||
trainset = LoadCifar10DatasetTrain(save, transform_train_valid)
|
||||
train_val_abs = int(len(trainset) * 0.8)
|
||||
train_subset, val_subset = random_split(trainset, [train_val_abs, len(trainset) - train_val_abs])
|
||||
|
||||
# Get Cifar 10 Dataloaders
|
||||
trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=4)
|
||||
|
||||
valloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=4)
|
||||
return trainloader, valloader
|
||||
|
||||
def Dataloader_train(save, batch_size):
|
||||
|
||||
# See utils/dataloader.py for data augmentations
|
||||
transform_train = GetCustTransform()
|
||||
|
||||
# Get Cifar 10 Datasets
|
||||
trainset = LoadCifar10DatasetTrain(save, transform_train)
|
||||
# Get Cifar 10 Dataloaders
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=4)
|
||||
|
||||
return trainloader
|
||||
|
||||
def Dataloader_test(save, batch_size):
|
||||
|
||||
# transformation test set
|
||||
transform_test = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# initialize test dataset and dataloader
|
||||
testset = LoadCifar10DatasetTest(save, transform_test)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
|
||||
shuffle=False, num_workers=4)
|
||||
|
||||
return testloader
|
||||
|
||||
def imshow(im):
|
||||
image = im.cpu().clone().detach().numpy()
|
||||
image = image.transpose(1, 2, 0)
|
||||
image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5)) # unnormalize
|
||||
plt.imshow(image)
|
||||
plt.show()
|
||||
|
||||
def imretrun(im):
|
||||
image = im.cpu().clone().detach().numpy()
|
||||
image = image.transpose(1, 2, 0)
|
||||
image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5)) # unnormalize
|
||||
return image
|
210
labml_nn/cnn/utils/train.py
Normal file
210
labml_nn/cnn/utils/train.py
Normal file
@ -0,0 +1,210 @@
|
||||
#!/bin/python
|
||||
|
||||
import torch.nn as nn
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from models.cnn import GetCNN
|
||||
from ray import tune
|
||||
from utils.dataloader import * # Get the transforms
|
||||
|
||||
|
||||
class Trainer():
|
||||
def __init__(self, name="default", device=None):
|
||||
self.device = device
|
||||
|
||||
self.epoch = 0
|
||||
self.start_epoch = 0
|
||||
self.name = name
|
||||
|
||||
# Train function
|
||||
def Train(self, net, trainloader, testloader, cost, opt, epochs = 25):
|
||||
|
||||
self.net = net
|
||||
self.trainloader = trainloader
|
||||
self.testloader = testloader
|
||||
|
||||
# Optimizer and Cost function
|
||||
self.opt = opt
|
||||
self.cost = cost
|
||||
|
||||
# Bookkeeping
|
||||
train_loss = torch.zeros(epochs)
|
||||
self.epoch = 0
|
||||
train_steps = 0
|
||||
accuracy = torch.zeros(epochs)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(self.start_epoch, self.start_epoch+epochs):
|
||||
self.epoch = epoch+1
|
||||
self.net.train() # Enable Dropout
|
||||
|
||||
# Iterating over train data
|
||||
for data in self.trainloader:
|
||||
if self.device:
|
||||
images, labels = data[0].to(self.device), data[1].to(self.device)
|
||||
else:
|
||||
images, labels = data[0], data[1]
|
||||
|
||||
self.opt.zero_grad()
|
||||
|
||||
# Forward + backward + optimize
|
||||
outputs = self.net(images)
|
||||
epoch_loss = self.cost(outputs, labels)
|
||||
epoch_loss.backward()
|
||||
self.opt.step()
|
||||
train_steps+=1
|
||||
|
||||
train_loss[epoch] += epoch_loss.item()
|
||||
loss_train = train_loss[epoch] / train_steps
|
||||
|
||||
accuracy[epoch] = self.Test() #correct / total
|
||||
|
||||
print("Epoch %d LR %.6f Train Loss: %.3f Test Accuracy: %.3f" % (
|
||||
self.epoch, self.opt.param_groups[0]['lr'], loss_train, accuracy[epoch]))
|
||||
|
||||
# Save best model
|
||||
if accuracy[epoch] >= torch.max(accuracy):
|
||||
self.save_best_model({
|
||||
'epoch': self.epoch,
|
||||
'state_dict': self.net.state_dict(),
|
||||
'optimizer': self.opt.state_dict(),
|
||||
})
|
||||
|
||||
self.plot_accuracy(accuracy)
|
||||
|
||||
# Test over testloader loop
|
||||
def Test(self, net = None, save=None):
|
||||
# Initialize dataloader
|
||||
if save == None:
|
||||
testloader = self.testloader
|
||||
else:
|
||||
testloader = Dataloader_test(save, batch_size=128)
|
||||
|
||||
# Initialize net
|
||||
if net == None:
|
||||
net = self.net
|
||||
|
||||
# Disable Dropout
|
||||
net.eval()
|
||||
|
||||
# Bookkeeping
|
||||
correct = 0.0
|
||||
total = 0.0
|
||||
|
||||
# Infer the model
|
||||
with torch.no_grad():
|
||||
for data in testloader:
|
||||
if self.device:
|
||||
images, labels = data[0].to(self.device), data[1].to(self.device)
|
||||
else:
|
||||
images, labels = data[0], data[1]
|
||||
|
||||
outputs = net(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
# compute the final accuracy
|
||||
accuracy = correct / total
|
||||
return accuracy
|
||||
|
||||
# Train function modified for ray schedulers
|
||||
def Train_ray(self, config, checkpoint_dir=None, data_dir=None):
|
||||
epochs = 25
|
||||
|
||||
self.net = GetCNN(config["l1"], config["l2"])
|
||||
self.net.to(self.device)
|
||||
|
||||
trainloader, valloader = Dataloader_train_valid(data_dir, batch_size=config["batch_size"])
|
||||
|
||||
# Optimizer and Cost function
|
||||
self.opt = torch.optim.Adam(self.net.parameters(), lr=config["lr"], betas=(0.9, 0.95), weight_decay=config["decay"])
|
||||
self.cost = nn.CrossEntropyLoss()
|
||||
|
||||
# restoring checkpoint
|
||||
if checkpoint_dir:
|
||||
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# checkpoint = checkpoint_dir
|
||||
model_state, optimizer_state = torch.load(checkpoint)
|
||||
self.net.load_state_dict(model_state)
|
||||
self.opt.load_state_dict(optimizer_state)
|
||||
|
||||
self.net.train()
|
||||
|
||||
# Record loss/accuracies
|
||||
train_loss = torch.zeros(epochs)
|
||||
self.epoch = 0
|
||||
train_steps = 0
|
||||
for epoch in range(self.start_epoch, self.start_epoch+epochs):
|
||||
self.epoch = epoch+1
|
||||
|
||||
self.net.train() # Enable Dropout
|
||||
for data in trainloader:
|
||||
# Get the inputs; data is a list of [inputs, labels]
|
||||
if self.device:
|
||||
images, labels = data[0].to(self.device), data[1].to(self.device)
|
||||
else:
|
||||
images, labels = data[0], data[1]
|
||||
|
||||
self.opt.zero_grad()
|
||||
# Forward + backward + optimize
|
||||
outputs = self.net(images)
|
||||
epoch_loss = self.cost(outputs, labels)
|
||||
epoch_loss.backward()
|
||||
self.opt.step()
|
||||
train_steps+=1
|
||||
|
||||
train_loss[epoch] += epoch_loss.item()
|
||||
|
||||
# Validation loss
|
||||
val_loss = 0.0
|
||||
val_steps = 0
|
||||
total = 0
|
||||
correct = 0
|
||||
self.net.eval()
|
||||
for data in valloader:
|
||||
with torch.no_grad():
|
||||
# Get the inputs; data is a list of [inputs, labels]
|
||||
if self.device:
|
||||
images, labels = data[0].to(self.device), data[1].to(self.device)
|
||||
else:
|
||||
images, labels = data[0], data[1]
|
||||
|
||||
# Forward + backward + optimize
|
||||
outputs = self.net(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
loss = self.cost(outputs, labels)
|
||||
val_loss += loss.cpu().numpy()
|
||||
val_steps += 1
|
||||
|
||||
# Save checkpoints
|
||||
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save(
|
||||
(self.net.state_dict(), self.opt.state_dict()), path)
|
||||
|
||||
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
|
||||
|
||||
return train_loss
|
||||
|
||||
def plot_accuracy(self, accuracy, criterea = "accuracy"):
|
||||
plt.plot(accuracy.numpy())
|
||||
plt.ylabel("Accuracy" if criterea == "accuracy" else "Loss")
|
||||
plt.xlabel("Epochs")
|
||||
plt.show()
|
||||
|
||||
|
||||
def save_checkpoint(self, state):
|
||||
directory = os.path.dirname("./save/%s-checkpoints/"%(self.name))
|
||||
if not os.path.exists(directory):
|
||||
os.mkdir(directory)
|
||||
torch.save(state, "%s/model_epoch_%s.pt" %(directory, self.epoch))
|
||||
|
||||
def save_best_model(self, state):
|
||||
directory = os.path.dirname("./save/%s-best-model/"%(self.name))
|
||||
if not os.path.exists(directory):
|
||||
os.mkdir(directory)
|
||||
torch.save(state, "%s/model.pt" %(directory))
|
Reference in New Issue
Block a user