mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 02:07:56 +08:00
Implemented CNN filter visualization and Hyperparameter tuning (Ray)
committed by Varuna Jayasiri
parent ba5c7200e8
commit 73ce42ec4c
164
labml_nn/cnn/cnn_visualization.py
Executable file
@ -0,0 +1,164 @@
#!/bin/python

import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from functools import partial
from skimage.filters import sobel, sobel_h, roberts

from models.cnn import CNN
from utils.dataloader import *  # also brings in transforms, plt, np and imshow
from utils.train import Trainer

# Check if GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

# CIFAR-10 dataset location
save = './data/Cifar10'

# Transformations for the train set
transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the train dataset and dataloader
trainset = LoadCifar10DatasetTrain(save, transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)

# Transformations for the test set
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the test dataset and dataloader
testset = LoadCifar10DatasetTest(save, transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)


# Create the CNN model
def GetCNN():
    cnn = CNN(in_features=(32, 32, 3),
              out_features=10,
              conv_filters=[32, 32, 64, 64],
              conv_kernel_size=[3, 3, 3, 3],
              conv_strides=[1, 1, 1, 1],
              conv_pad=[0, 0, 0, 0],
              max_pool_kernels=[None, (2, 2), None, (2, 2)],
              max_pool_strides=[None, 2, None, 2],
              use_dropout=False,
              use_batch_norm=True,  # False
              actv_func=["relu", "relu", "relu", "relu"],
              device=device
              )
    return cnn


model = GetCNN()

# Display the model specifications
summary(model, (3, 32, 32))

# Send the model to the GPU
model.to(device)

# Specify the optimizer
opt = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.95))

# Specify the loss function
cost = nn.CrossEntropyLoss()

# Train the model
trainer = Trainer(device=device, name="Basic_CNN")
epochs = 5
trainer.Train(model, trainloader, testloader, cost=cost, opt=opt, epochs=epochs)

# Load the best saved model for inference
model_loaded = GetCNN()

# Location of the saved model
PATH = "./save/Basic_CNN-best-model/model.pt"
checkpoint = torch.load(PATH)

# Load the saved weights
model_loaded.load_state_dict(checkpoint['state_dict'])

# Initialization for the hooks and for storing the activations of the ReLU layers
activation = {}
hooks = []


# The hook function saves the activation of a particular layer
def hook_fn(model, input, output, name):
    activation[name] = output.cpu().detach().numpy()


# Register the hooks
count = 0
conv_count = 0
for name, layer in model_loaded.named_modules():
    if isinstance(layer, nn.ReLU):
        count += 1
        hook = layer.register_forward_hook(partial(hook_fn, name=f"{layer._get_name()}-{count}"))  # alternatively: f"{type(layer).__name__}-{name}"
        hooks.append(hook)
    if isinstance(layer, nn.Conv2d):
        conv_count += 1
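# register_forward_hook calls hook_fn(module, input, output) after every
# forward pass, and functools.partial pins the extra `name` argument, so each
# registered hook writes its layer's output into `activation[name]`.
# A minimal sketch of the same pattern on a hypothetical toy model:
#
#   toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#   toy[1].register_forward_hook(partial(hook_fn, name="ReLU-toy"))
#   _ = toy(torch.randn(1, 3, 8, 8))
#   activation["ReLU-toy"].shape  # (1, 8, 6, 6)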

# Display the image used for inference
data, _ = trainset[15]
imshow(data)

# Run inference so the hooks record the activations of the ReLU layers
output = model_loaded(data[None].to(device))

# Remove the hooks
for hook in hooks:
    hook.remove()


# Display the output of a particular ReLU layer
def output_one_layer(layer_num):
    assert 1 <= layer_num <= len(activation), "Wrong layer number"

    layer_name = f"ReLU-{layer_num}"
    act = activation[layer_name]
    if act.shape[1] == 32:
        rows = 4
        columns = 8
    elif act.shape[1] == 64:
        rows = 8
        columns = 8
    else:
        raise ValueError(f"Unexpected number of channels: {act.shape[1]}")

    fig = plt.figure(figsize=(rows, columns))
    for idx in range(1, columns * rows + 1):
        fig.add_subplot(rows, columns, idx)
        plt.imshow(sobel(act[0][idx - 1]), cmap=plt.cm.gray)

        # Try different filters
        # plt.imshow(act[0][idx - 1], cmap='viridis', vmin=0, vmax=act.max())
        # plt.imshow(act[0][idx - 1], cmap='hot')
        # plt.imshow(roberts(act[0][idx - 1]), cmap=plt.cm.gray)
        # plt.imshow(sobel_h(act[0][idx - 1]), cmap=plt.cm.gray)

        plt.axis('off')

    plt.tight_layout()
    plt.show()


# Display the outputs of all ReLU layers that follow convolution layers
def output_all_layers():
    for [name, output], count in zip(activation.items(), range(conv_count)):
        if output.shape[1] == 32:
            _, axs = plt.subplots(8, 4, figsize=(8, 4))
        elif output.shape[1] == 64:
            _, axs = plt.subplots(8, 8, figsize=(8, 8))

        for ax, out in zip(np.ravel(axs), output[0]):
            ax.imshow(sobel(out), cmap=plt.cm.gray)
            ax.axis('off')

        plt.suptitle(name)
        plt.tight_layout()
        plt.show()


# Choose either one to display
output_one_layer(layer_num=3)  # choose the layer number
output_all_layers()
193
labml_nn/cnn/models/cnn.py
Executable file
@ -0,0 +1,193 @@
#!/bin/python

import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Use the formula:
#   [(W - K + 2P) / S] + 1
# where:
#   W: the input volume size for each dimension
#   K: the kernel size
#   P: the padding
#   S: the stride

def CalcConvFormula(W, K, P, S):
    return int(np.floor(((W - K + 2 * P) / S) + 1))
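# Quick check with the shapes used by GetCNN below: a 32x32 input through a
# 3x3 kernel, no padding, stride 1 gives
#   CalcConvFormula(32, 3, 0, 1) == (32 - 3 + 0) // 1 + 1 == 30
# and a 2x2 max-pool with stride 2 on a 28x28 map gives
#   CalcConvFormula(28, 2, 0, 2) == (28 - 2) // 2 + 1 == 14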

# https://stackoverflow.com/questions/53580088/calculate-the-output-size-in-convolution-layer
# Calculate the output shape after applying a convolution
def CalcConvOutShape(in_shape, kernel_size, padding, stride, out_filters):
    # Handle both int and tuple kernel shapes
    if type(kernel_size) == int:
        out_shape = [CalcConvFormula(in_shape[i], kernel_size, padding, stride) for i in range(2)]
    else:
        out_shape = [CalcConvFormula(in_shape[i], kernel_size[i], padding, stride) for i in range(2)]

    return (out_shape[0], out_shape[1], out_filters)  # , batch_size... but not necessary.


class CNN(nn.Module):
    def __init__(self
                 , in_features
                 , out_features
                 , conv_filters
                 , conv_kernel_size
                 , conv_strides
                 , conv_pad
                 , actv_func
                 , max_pool_kernels
                 , max_pool_strides
                 , l1=120
                 , l2=84
                 , MLP=None
                 , pre_module_list=None
                 , use_dropout=False
                 , use_batch_norm=False
                 , device="cpu"
                 ):
        super(CNN, self).__init__()

        # General model properties
        self.in_features = in_features
        self.out_features = out_features

        # Convolution operations
        self.conv_filters = conv_filters
        self.conv_kernel_size = conv_kernel_size
        self.conv_strides = conv_strides
        self.conv_pad = conv_pad

        # Convolution activations
        self.actv_func = actv_func

        # Max pools
        self.max_pool_kernels = max_pool_kernels
        self.max_pool_strides = max_pool_strides

        # Regularization
        self.use_dropout = use_dropout
        self.use_batch_norm = use_batch_norm

        # Tunable parameters
        self.l1 = l1
        self.l2 = l2

        # Number of conv/pool/act/batch_norm/dropout layers we add
        self.n_conv_layers = len(self.conv_filters)

        # Create the module list
        if pre_module_list:
            self.module_list = pre_module_list
        else:
            self.module_list = nn.ModuleList()

        self.shape_list = []
        self.shape_list.append(self.in_features)

        self.build_()

        # Send to GPU
        self.device = device
        self.to(self.device)

    def build_(self):
        # Track the current feature-map shape
        cur_shape = self.GetCurShape()

        for i in range(self.n_conv_layers):
            if i == 0:
                if len(self.in_features) == 2:
                    in_channels = 1
                else:
                    in_channels = self.in_features[2]
            else:
                in_channels = self.conv_filters[i - 1]

            cur_shape = CalcConvOutShape(cur_shape, self.conv_kernel_size[i], self.conv_pad[i], self.conv_strides[i],
                                         self.conv_filters[i])
            self.shape_list.append(cur_shape)

            conv = nn.Conv2d(in_channels=in_channels,
                             out_channels=self.conv_filters[i],
                             kernel_size=self.conv_kernel_size[i],
                             padding=self.conv_pad[i],
                             stride=self.conv_strides[i]
                             )
            self.module_list.append(conv)

            if self.use_batch_norm:
                self.module_list.append(nn.BatchNorm2d(cur_shape[2]))

            if self.use_dropout:
                self.module_list.append(nn.Dropout(p=0.15))

            # Add the activation function
            if self.actv_func[i]:
                self.module_list.append(GetActivation(name=self.actv_func[i]))

            if self.max_pool_kernels:
                if self.max_pool_kernels[i]:
                    self.module_list.append(nn.MaxPool2d(self.max_pool_kernels[i], stride=self.max_pool_strides[i]))
                    cur_shape = CalcConvOutShape(cur_shape, self.max_pool_kernels[i], 0, self.max_pool_strides[i],
                                                 cur_shape[2])
                    self.shape_list.append(cur_shape)

        # Adding the MLP head
        s = self.GetCurShape()
        in_features = s[0] * s[1] * s[2]
        self.module_list.append(nn.Linear(in_features, self.l1))
        self.module_list.append(nn.ReLU())
        self.module_list.append(nn.Linear(self.l1, self.l2))
        self.module_list.append(nn.ReLU())
        self.module_list.append(nn.Linear(self.l2, self.out_features))

    def forward(self, x):
        # Flatten once we reach the first Linear layer
        j = 0
        for i, module in enumerate(self.module_list):
            if isinstance(module, nn.Linear) and j == 0:
                x = torch.flatten(x.float(), start_dim=1)
                j = 1
            x = module(x)
        return x

    def GetCurShape(self):
        return self.shape_list[-1]


def GetCNN(l1=120, l2=84):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cnn = CNN(in_features=(32, 32, 3),
              out_features=10,
              conv_filters=[32, 32, 64, 64],  # , 128, 256, 512
              conv_kernel_size=[3, 3, 3, 3],  # , 3, 3, 1
              conv_strides=[1, 1, 1, 1],  # , 1, 1, 1
              conv_pad=[0, 0, 0, 0],
              actv_func=["relu", "relu", "relu", "relu"],  # , "relu", "relu", "relu"
              max_pool_kernels=[None, (2, 2), None, (2, 2)],  # , None, None, None
              max_pool_strides=[None, 2, None, 2],  # , None, None, None
              l1=l1,
              l2=l2,
              use_dropout=False,
              use_batch_norm=True,  # False
              device=device
              )

    return cnn
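# With the default configuration above the feature-map shapes evolve as
#   (32,32,3) -conv3-> (30,30,32) -conv3-> (28,28,32) -pool2-> (14,14,32)
#             -conv3-> (12,12,64) -conv3-> (10,10,64) -pool2-> (5,5,64)
# so the first nn.Linear layer of the MLP head receives 5 * 5 * 64 = 1600
# input features.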


def GetActivation(name="relu"):
    if name == "relu":
        return nn.ReLU()
    elif name == "leakyrelu":
        return nn.LeakyReLU()
    elif name == "Sigmoid":
        return nn.Sigmoid()
    elif name == "Tanh":
        return nn.Tanh()
    elif name == "Identity":
        return nn.Identity()
    else:
        return nn.ReLU()
106
labml_nn/cnn/ray_tune.py
Normal file
@ -0,0 +1,106 @@
#!/bin/python

import numpy as np
import os
import torch
from ray import tune
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from utils.train import Trainer
from models.cnn import GetCNN

# Check if GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

# Tuning settings
num_samples = 40  # number of trials
max_num_epochs = 25
gpus_per_trial = 1

# CIFAR-10 dataset location
data_dir = './data/Cifar10'

"""config - a dict of hyperparameters

Hyperparameters selected for tuning:
l1 : Number of units in the first fully connected layer
l2 : Number of units in the second fully connected layer
lr : Learning rate
decay : Weight-decay rate for regularization
batch_size : Batch size of the test and train data
"""
config = {
    "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),  # e.g. 4, 8, 16, ..., 256
    "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),  # e.g. 4, 8, 16, ..., 256
    "lr": tune.loguniform(1e-4, 1e-1),  # sampled from a log-uniform distribution
    "decay": tune.sample_from(lambda _: 10 ** np.random.randint(-7, -3)),  # e.g. 1e-7, 1e-6, ..., 1e-4
    "batch_size": tune.choice([32, 64, 128, 256])
}
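# A single trial's sampled config might look like, for example:
#   {"l1": 64, "l2": 16, "lr": 3.2e-3, "decay": 1e-6, "batch_size": 128}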

# Create the trainer
trainer = Trainer(device=device)

"""ASHA (Asynchronous Successive Halving Algorithm) scheduler

max_t : Maximum number of units per trial (can be time or epochs)
grace_period : Stop a trial after this many units if the model is not performing well (can be time or epochs)
reduction_factor : Sets the halving rate
"""
scheduler = ASHAScheduler(
    max_t=max_num_epochs,
    grace_period=4,
    reduction_factor=4)


"""Population Based Training scheduler

time_attr : Can be time or epochs
metric : Objective of training (loss or accuracy)
perturbation_interval : Perturbation occurs after the specified number of units (can be time or epochs)
hyperparam_mutations : Hyperparameters to mutate
"""
# Note: this assignment overrides the ASHA scheduler above, so PBT is the
# active scheduler; keep only the one you want to use. PBT sets metric/mode
# itself, so metric/mode must not also be passed to tune.run below.
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",  # epochs
    metric='loss',  # the loss is the objective function
    mode='min',  # minimizing the loss is the objective of training
    perturbation_interval=5.0,  # perturb after 5 epochs
    hyperparam_mutations={
        "lr": [1e-3, 5e-4, 1e-4, 5e-4, 1e-5],  # choose from the given learning rates
        "batch_size": [64, 128, 256],  # choose from the given batch sizes
        "decay": tune.uniform(10 ** -8, 10 ** -4)  # sample from a uniform distribution
    }
)
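
# tune.with_parameters (used below) wraps the trainable so that each trial
# receives its sampled `config` dict as the first argument, while fixed
# keyword arguments such as data_dir are forwarded to every call.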
result = tune.run(
    tune.with_parameters(trainer.Train_ray, data_dir=data_dir),
    name="ray_test_basic-CNN",  # name for identifying models (checkpoints)
    scheduler=scheduler,  # the selected scheduler, PBT or ASHA
    resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},  # number of CPUs and GPUs per trial
    config=config,  # the hyperparameter search space defined above
    stop={
        "training_iteration": max_num_epochs,  # stopping criterion
    },
    # metric="loss",  # uncomment when using the ASHA scheduler
    # mode="min",  # uncomment when using the ASHA scheduler
    num_samples=num_samples,
    verbose=True,  # keep True to check how training progresses
    fail_fast=True,  # fail on the first error
    keep_checkpoints_num=5,  # number of checkpoints to keep per trial
)

best_trial = result.get_best_trial("loss", "min", "last")
print("Best configuration: {}".format(best_trial.config))
print("Best validation loss: {}".format(best_trial.last_result["loss"]))
print("Best validation accuracy: {}".format(
    best_trial.last_result["accuracy"]))

best_trained_model = GetCNN(best_trial.config["l1"], best_trial.config["l2"])
best_trained_model.to(device)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)

# Check the accuracy of the best model
test_acc = trainer.Test(best_trained_model, save=data_dir)
print("Best Test accuracy: {}".format(test_acc))
BIN
labml_nn/cnn/save/Basic_CNN-best-model/model.pt
Normal file
Binary file not shown.
83
labml_nn/cnn/utils/dataloader.py
Normal file
@ -0,0 +1,83 @@
#!/bin/python

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split
import matplotlib.pyplot as plt
import numpy as np


def LoadCifar10DatasetTrain(save, transform=None):
    trainset = torchvision.datasets.CIFAR10(root=save, train=True,
                                            download=True, transform=transform)
    return trainset


def LoadCifar10DatasetTest(save, transform):
    return torchvision.datasets.CIFAR10(root=save, train=False,
                                        download=False, transform=transform)


def GetCustTransform():
    transform_train = transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomCrop(32, (2, 2), pad_if_needed=False, padding_mode='constant'),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_train


def Dataloader_train_valid(save, batch_size):
    # See GetCustTransform above for the data augmentations
    transform_train_valid = GetCustTransform()

    # Get the CIFAR-10 dataset and split it 80/20 into train/validation
    trainset = LoadCifar10DatasetTrain(save, transform_train_valid)
    train_val_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(trainset, [train_val_abs, len(trainset) - train_val_abs])

    # Get the CIFAR-10 dataloaders
    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,
                                              shuffle=True, num_workers=4)
    valloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size,
                                            shuffle=True, num_workers=4)
    return trainloader, valloader


def Dataloader_train(save, batch_size):
    # See GetCustTransform above for the data augmentations
    transform_train = GetCustTransform()

    # Get the CIFAR-10 dataset
    trainset = LoadCifar10DatasetTrain(save, transform_train)
    # Get the CIFAR-10 dataloader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=4)

    return trainloader


def Dataloader_test(save, batch_size):
    # Transformations for the test set
    transform_test = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Initialize the test dataset and dataloader
    testset = LoadCifar10DatasetTest(save, transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=4)

    return testloader


def imshow(im):
    image = im.cpu().clone().detach().numpy()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))  # unnormalize
    plt.imshow(image)
    plt.show()
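# The datasets are normalized per channel with mean 0.5 and std 0.5, i.e.
# x_norm = (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1]; multiplying by 0.5 and
# adding 0.5 inverts this for display: x = x_norm * 0.5 + 0.5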


def imretrun(im):
    image = im.cpu().clone().detach().numpy()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))  # unnormalize
    return image
210
labml_nn/cnn/utils/train.py
Normal file
@ -0,0 +1,210 @@
#!/bin/python

import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from models.cnn import GetCNN
from ray import tune
from utils.dataloader import *  # get the transforms and dataloaders


class Trainer():
    def __init__(self, name="default", device=None):
        self.device = device

        self.epoch = 0
        self.start_epoch = 0
        self.name = name

    # Train function
    def Train(self, net, trainloader, testloader, cost, opt, epochs=25):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

        # Optimizer and cost function
        self.opt = opt
        self.cost = cost

        # Bookkeeping
        train_loss = torch.zeros(epochs)
        self.epoch = 0
        accuracy = torch.zeros(epochs)

        # Training loop
        for epoch in range(self.start_epoch, self.start_epoch + epochs):
            self.epoch = epoch + 1
            self.net.train()  # Enable Dropout
            train_steps = 0  # reset each epoch so the loss average below is correct

            # Iterate over the train data
            for data in self.trainloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data[0], data[1]

                self.opt.zero_grad()

                # Forward + backward + optimize
                outputs = self.net(images)
                epoch_loss = self.cost(outputs, labels)
                epoch_loss.backward()
                self.opt.step()
                train_steps += 1

                train_loss[epoch] += epoch_loss.item()

            loss_train = train_loss[epoch] / train_steps

            accuracy[epoch] = self.Test()  # correct / total

            print("Epoch %d LR %.6f Train Loss: %.3f Test Accuracy: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'], loss_train, accuracy[epoch]))

            # Save the best model so far
            if accuracy[epoch] >= torch.max(accuracy):
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        self.plot_accuracy(accuracy)

    # Test loop over the testloader
    def Test(self, net=None, save=None):
        # Initialize the dataloader
        if save is None:
            testloader = self.testloader
        else:
            testloader = Dataloader_test(save, batch_size=128)

        # Initialize the net
        if net is None:
            net = self.net

        # Disable Dropout
        net.eval()

        # Bookkeeping
        correct = 0.0
        total = 0.0

        # Run inference
        with torch.no_grad():
            for data in testloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data[0], data[1]

                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Compute the final accuracy
        accuracy = correct / total
        return accuracy

    # Train function modified for the Ray schedulers
    def Train_ray(self, config, checkpoint_dir=None, data_dir=None):
        epochs = 25

        self.net = GetCNN(config["l1"], config["l2"])
        self.net.to(self.device)

        trainloader, valloader = Dataloader_train_valid(data_dir, batch_size=config["batch_size"])

        # Optimizer and cost function
        self.opt = torch.optim.Adam(self.net.parameters(), lr=config["lr"], betas=(0.9, 0.95), weight_decay=config["decay"])
        self.cost = nn.CrossEntropyLoss()

        # Restore from a checkpoint if one is given
        if checkpoint_dir:
            checkpoint = os.path.join(checkpoint_dir, "checkpoint")
            # checkpoint = checkpoint_dir
            model_state, optimizer_state = torch.load(checkpoint)
            self.net.load_state_dict(model_state)
            self.opt.load_state_dict(optimizer_state)

        self.net.train()

        # Record losses
        train_loss = torch.zeros(epochs)
        self.epoch = 0
        train_steps = 0
        for epoch in range(self.start_epoch, self.start_epoch + epochs):
            self.epoch = epoch + 1

            self.net.train()  # Enable Dropout
            for data in trainloader:
                # Get the inputs; data is a list of [inputs, labels]
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data[0], data[1]

                self.opt.zero_grad()
                # Forward + backward + optimize
                outputs = self.net(images)
                epoch_loss = self.cost(outputs, labels)
                epoch_loss.backward()
                self.opt.step()
                train_steps += 1

                train_loss[epoch] += epoch_loss.item()

            # Validation loss
            val_loss = 0.0
            val_steps = 0
            total = 0
            correct = 0
            self.net.eval()
            for data in valloader:
                with torch.no_grad():
                    # Get the inputs; data is a list of [inputs, labels]
                    if self.device:
                        images, labels = data[0].to(self.device), data[1].to(self.device)
                    else:
                        images, labels = data[0], data[1]

                    # Forward pass only
                    outputs = self.net(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    loss = self.cost(outputs, labels)
                    val_loss += loss.cpu().numpy()
                    val_steps += 1

            # Save a checkpoint for this epoch
            with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save(
                    (self.net.state_dict(), self.opt.state_dict()), path)
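            # The checkpoint is a plain (model_state, optimizer_state) tuple,
            # which is why ray_tune.py restores it with
            #   model_state, optimizer_state = torch.load(checkpoint_path)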

            tune.report(loss=(val_loss / val_steps), accuracy=correct / total)

        return train_loss

    def plot_accuracy(self, accuracy, criteria="accuracy"):
        plt.plot(accuracy.numpy())
        plt.ylabel("Accuracy" if criteria == "accuracy" else "Loss")
        plt.xlabel("Epochs")
        plt.show()

    def save_checkpoint(self, state):
        directory = os.path.dirname("./save/%s-checkpoints/" % (self.name))
        if not os.path.exists(directory):
            os.makedirs(directory)  # makedirs also creates the parent ./save directory
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))

    def save_best_model(self, state):
        directory = os.path.dirname("./save/%s-best-model/" % (self.name))
        if not os.path.exists(directory):
            os.makedirs(directory)  # makedirs also creates the parent ./save directory
        torch.save(state, "%s/model.pt" % (directory))