diff --git a/labml_nn/cnn/cnn_visualization.py b/labml_nn/cnn/cnn_visualization.py new file mode 100755 index 00000000..8d14765a --- /dev/null +++ b/labml_nn/cnn/cnn_visualization.py @@ -0,0 +1,164 @@ +#!/bin/python + +import torch.nn as nn +import torch.optim as optim +from torchsummary import summary +from functools import partial +from skimage.filters import sobel, sobel_h, roberts +from models.cnn import CNN +from utils.dataloader import * +from utils.train import Trainer + +# Check if GPU is available +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Device: " + str(device)) + +# Cifar 10 Datasets location +save='./data/Cifar10' + +# Transformations train +transform_train = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +# Load train dataset and dataloader +trainset = LoadCifar10DatasetTrain(save, transform_train) +trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, + shuffle=True, num_workers=4) + +# Transformations test +transform_test = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +# Load test dataset and dataloader +testset = LoadCifar10DatasetTest(save, transform_test) +testloader = torch.utils.data.DataLoader(testset, batch_size=64, + shuffle=False, num_workers=4) + +# Create CNN model +def GetCNN(): + cnn = CNN( in_features=(32,32,3), + out_features=10, + conv_filters=[32,32,64,64], + conv_kernel_size=[3,3,3,3], + conv_strides=[1,1,1,1], + conv_pad=[0,0,0,0], + max_pool_kernels=[None, (2,2), None, (2,2)], + max_pool_strides=[None,2,None,2], + use_dropout=False, + use_batch_norm=True, #False + actv_func=["relu", "relu", "relu", "relu"], + device=device + ) + + return cnn + +model = GetCNN() + +# Display model specifications +summary(model, (3,32,32)) + +# Send model to GPU +model.to(device) + +# Specify optimizer +opt = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.95)) + +# Specify loss function +cost = nn.CrossEntropyLoss() + +# Train the model +trainer = Trainer(device=device, name="Basic_CNN") +epochs = 5 +trainer.Train(model, trainloader, testloader, cost=cost, opt=opt, epochs=epochs) + +# Load best saved model for inference +model_loaded = GetCNN() + +# Specify location of saved model +PATH = "./save/Basic_CNN-best-model/model.pt" +checkpoint = torch.load(PATH) + +# load the saved model +model_loaded.load_state_dict(checkpoint['state_dict']) + +# intialization for hooks and storing activation of ReLU layers +activation = {} +hooks = [] + +# Hook function saves activation of a particular layer +def hook_fn(model, input, output, name): + activation[name] = output.cpu().detach().numpy() + +# Registering hooks +count =0 +conv_count = 0 +for name, layer in model_loaded.named_modules(): + if isinstance(layer, nn.ReLU): + count +=1 + hook = layer.register_forward_hook(partial(hook_fn, name=f"{layer._get_name()}-{count}")) #f"{type(layer).__name__}-{name}" + hooks.append(hook) + if isinstance(layer, nn.Conv2d): + conv_count += 1 + +# Displaying image used for inference +data, _ = trainset[15] +imshow(data) + +# Infering model to save activation of ReLU layers +output = model_loaded(data[None].to(device)) + +# Removing hooks +for hook in hooks: + hook.remove() + +# Function to display output of a particular ReLU layer +def output_one_layer(layer_num): + assert 1 <= layer_num <= len(activation), "Wrong layer number" + + layer_name = f"ReLu-{layer_num}" + act = activation[f"ReLU-{layer_num}"] + if act.shape[1]==32: + rows = 4 + columns = 8 + elif act.shape[1]==64: + rows = 8 + columns = 8 + + fig = plt.figure(figsize=(rows, columns)) + for idx in range(1, columns * rows + 1): + fig.add_subplot(rows, columns, idx) + plt.imshow(sobel(act[0][idx-1]), cmap=plt.cm.gray) + + # try different filters + # plt.imshow(act[0][idx-1], cmap='viridis', vmin=0, vmax=act.max()) + # plt.imshow(act[0][idx - 1], cmap='hot') + # plt.imshow(roberts(act[0][idx - 1]), cmap=plt.cm.gray) + # plt.imshow(sobel_h(act[0][idx-1]), cmap=plt.cm.gray) + + plt.axis('off') + + plt.tight_layout() + plt.show() + +# Function to display output of all ReLU layer after Convulution layers +def output_all_layers(): + for [name, output], count in zip(activation.items(), range(conv_count)): + if output.shape[1] == 32: + _, axs = plt.subplots(8, 4, figsize=(8, 4)) + elif output.shape[1] == 64: + _, axs = plt.subplots(8, 8, figsize=(8, 8)) + + for ax, out in zip(np.ravel(axs), output[0]): + ax.imshow(sobel(out), cmap=plt.cm.gray) + ax.axis('off') + + plt.suptitle(name) + plt.tight_layout() + plt.show() + +# Choose either one to display +output_one_layer(layer_num=3) # choose layer number +output_all_layers() + diff --git a/labml_nn/cnn/models/cnn.py b/labml_nn/cnn/models/cnn.py new file mode 100755 index 00000000..cf1dc984 --- /dev/null +++ b/labml_nn/cnn/models/cnn.py @@ -0,0 +1,193 @@ +#!/bin/python + +import numpy as np +import torch +import torch.nn as nn + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +# Use the formula: +# [(W-K+2P)/S] + 1 +# where: +# W: Is the input volume size for each dimension +# K: Is the kernel size +# P: Is the padding +# S: Is the stride + +def CalcConvFormula(W, K, P, S): + return int(np.floor(((W - K + 2 * P) / S) + 1)) + + +# https://stackoverflow.com/questions/53580088/calculate-the-output-size-in-convolution-layer +# Calculate the output shape after applying a convolution +def CalcConvOutShape(in_shape, kernel_size, padding, stride, out_filters): + # Multiple options for different kernel shapes + if type(kernel_size) == int: + out_shape = [CalcConvFormula(in_shape[i], kernel_size, padding, stride) for i in range(2)] + else: + out_shape = [CalcConvFormula(in_shape[i], kernel_size[i], padding, stride) for i in range(2)] + + return (out_shape[0], out_shape[1], out_filters) # , batch_size... but not necessary. + +class CNN(nn.Module): + def __init__(self + , in_features + , out_features + , conv_filters + , conv_kernel_size + , conv_strides + , conv_pad + , actv_func + , max_pool_kernels + , max_pool_strides + , l1=120 + , l2=84 + , MLP=None + , pre_module_list=None + , use_dropout=False + , use_batch_norm=False + , device="cpu" + ): + super(CNN, self).__init__() + + # Gerneral model Properties + self.in_features = in_features + self.out_features = out_features + + # Convolution operations + self.conv_filters = conv_filters + self.conv_kernel_size = conv_kernel_size + self.conv_strides = conv_strides + self.conv_pad = conv_pad + + # Convolution Activiations + self.actv_func = actv_func + + # Max Pools + self.max_pool_kernels = max_pool_kernels + self.max_pool_strides = max_pool_strides + + # Regularization + self.use_dropout = use_dropout + self.use_batch_norm = use_batch_norm + + # Tunable parameters + self.l1 = l1 + self.l2 = l2 + + # Number of conv/pool/act/batch_norm/dropout layers we add + self.n_conv_layers = len(self.conv_filters) + + # Create the module list + if pre_module_list: + self.module_list = pre_module_list + else: + self.module_list = nn.ModuleList() + + self.shape_list = [] + self.shape_list.append(self.in_features) + + self.build_() + + # Send to gpu + self.device = device + self.to(self.device) + + def build_(self): + # Track shape + cur_shape = self.GetCurShape() + + for i in range(self.n_conv_layers): + if i == 0: + if len(self.in_features) == 2: + in_channels = 1 + else: + in_channels = self.in_features[2] + else: + in_channels = self.conv_filters[i - 1] + + cur_shape = CalcConvOutShape(cur_shape, self.conv_kernel_size[i], self.conv_pad[i], self.conv_strides[i], + self.conv_filters[i]) + self.shape_list.append(cur_shape) + + conv = nn.Conv2d(in_channels=in_channels, + out_channels=self.conv_filters[i], + kernel_size=self.conv_kernel_size[i], + padding=self.conv_pad[i], + stride=self.conv_strides[i] + ) + self.module_list.append(conv) + + if self.use_batch_norm: + self.module_list.append(nn.BatchNorm2d(cur_shape[2])) + + if self.use_dropout: + self.module_list.append(nn.Dropout(p=0.15)) + + # Add the Activation function + if self.actv_func[i]: + self.module_list.append(GetActivation(name=self.actv_func[i])) + + if self.max_pool_kernels: + if self.max_pool_kernels[i]: + self.module_list.append(nn.MaxPool2d(self.max_pool_kernels[i], stride=self.max_pool_strides[i])) + cur_shape = CalcConvOutShape(cur_shape, self.max_pool_kernels[i], 0, self.max_pool_strides[i], + cur_shape[2]) + self.shape_list.append(cur_shape) + + # # Adding MLP + s = self.GetCurShape() + in_features = s[0] * s[1] * s[2] + self.module_list.append(nn.Linear(in_features, self.l1)) + self.module_list.append(nn.ReLU()) + self.module_list.append(nn.Linear(self.l1, self.l2)) + self.module_list.append(nn.ReLU()) + self.module_list.append(nn.Linear(self.l2, self.out_features)) + + def forward(self, x): + j = 0 + for i, module in enumerate(self.module_list): + if isinstance(module, nn.Linear) and j == 0: + x = torch.flatten(x.float(), start_dim=1) + j = 1 + x = module(x) + return x + + def GetCurShape(self): + return self.shape_list[-1] + +def GetCNN(l1=120, l2=84): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cnn = CNN(in_features=(32, 32, 3), + out_features=10, + conv_filters=[32, 32, 64, 64], # , 128, 256, 512 + conv_kernel_size=[3, 3, 3, 3], # ,3,3,1 + conv_strides=[1, 1, 1, 1], # ,1,1,1 + conv_pad=[0, 0, 0, 0, 0, 0, 0], + actv_func=["relu", "relu", "relu", "relu"], # , "relu", "relu", "relu" + max_pool_kernels=[None, (2, 2), None, (2, 2)], # , None, None, None + max_pool_strides=[None, 2, None, 2], # , None,None, None + l1=l1, + l2=l2, + use_dropout=False, + use_batch_norm=True, # False + device=device + ) + + return cnn + + +def GetActivation(name="relu"): + if name == "relu": + return nn.ReLU() + elif name == "leakyrelu": + return nn.LeakyReLU() + elif name == "Sigmoid": + return nn.Sigmoid() + elif name == "Tanh": + return nn.Tanh() + elif name == "Identity": + return nn.Identity() + else: + return nn.ReLU() \ No newline at end of file diff --git a/labml_nn/cnn/ray_tune.py b/labml_nn/cnn/ray_tune.py new file mode 100644 index 00000000..18ee7040 --- /dev/null +++ b/labml_nn/cnn/ray_tune.py @@ -0,0 +1,106 @@ +#!/bin/python + +import numpy as np +import os +import torch +from ray import tune +from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining +from utils.train import Trainer +from models.cnn import GetCNN + +# Check if GPU is available +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Device: " + str(device)) + +# +num_samples= 40 # for multiple trials +max_num_epochs= 25 +gpus_per_trial= 1 + +# Cifar 10 Datasets location +data_dir = './data/Cifar10' + +"""config - returns a dict of hyperparameters + +Selecting different hyperparameters for tuning + l1 : Number of units in first fully connected layer + l2 : Number of units in second fully connected layer + lr : Learning rate + decay : Decay rate for regularization + batch_size : Batch size of test and train data +""" +config = { + "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), # eg. 4, 8, 16 .. 512 + "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), # eg. 4, 8, 16 .. 512 + "lr": tune.loguniform(1e-4, 1e-1), # Sampling from log uniform distribution + "decay": tune.sample_from(lambda _: 10 ** np.random.randint(-7, -3)), # eg. 1e-7, 1e-6, .. 1e-3 + "batch_size": tune.choice([32, 64, 128, 256]) +} + +# calling trainer +trainer = Trainer(device=device) + +"""ASHA (Asynchronous Successive Halving Algorithm) scheduler + max_t : Maximum number of units per trail (can be time or epochs) + grace_period : Stop trials after specific number of unit if model is not performing well (can be time or epochs) + reduction_factor : Set halving rate +""" +scheduler = ASHAScheduler( + max_t=max_num_epochs, + grace_period=4, + reduction_factor=4) + + + +"""Population based training scheduler + time_attr : Can be time or epochs + metric : Objective of training (loss or accuracy) + perturbation_interval : Perturbation occur after specified unit (can be time or epochs) + hyperparam_mutations : Hyperparameters to mutate +""" +scheduler = PopulationBasedTraining( + time_attr= "training_iteration", # epochs + metric='loss', # loss is objective function + mode='min', # minimizing loss is objective of training + perturbation_interval=5.0, # after 5 epochs perturbate + hyperparam_mutations={ + "lr": [1e-3, 5e-4, 1e-4, 5e-4, 1e-5], # choose from given learning rates + "batch_size": [64, 128, 256], # choose from given batch sizes + "decay": tune.uniform(10**-8, 10**-4) # sample from uniform distribution + } + ) + +result = tune.run( + tune.with_parameters(trainer.Train_ray, data_dir=data_dir), + name="ray_test_basic-CNN", # name for identifying models (checkpoints) + scheduler=scheduler, # select scheduler PBT or ASHA + resources_per_trial={"cpu": 8, "gpu": gpus_per_trial}, # select number of CPUs or GPUs + config=config, # input config dict consisting of different hyperparameters + stop={ + "training_iteration": max_num_epochs, # stopping criterea + }, + metric="loss", # uncomment for ASHA scheduler + mode="min", # uncomment for ASHA scheduler + num_samples=num_samples, + verbose=True, # keep to true to check how training progresses + fail_fast=True, # fail on first error + keep_checkpoints_num=5, # number of checkpoints to be saved per num_samples + +) + +best_trial = result.get_best_trial("loss", "min", "last") +print("Best configuration: {}".format(best_trial.config)) +print("Best validation loss: {}".format(best_trial.last_result["loss"])) +print("Best validation accuracy: {}".format( + best_trial.last_result["accuracy"])) + + +best_trained_model = GetCNN(best_trial.config["l1"], best_trial.config["l2"]) +best_trained_model.to(device) +checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") +model_state, optimizer_state = torch.load(checkpoint_path) +best_trained_model.load_state_dict(model_state) + +# Check accuracy of best model +test_acc = trainer.Test(best_trained_model, save=data_dir) +print("Best Test accuracy: {}".format(test_acc)) \ No newline at end of file diff --git a/labml_nn/cnn/save/Basic_CNN-best-model/model.pt b/labml_nn/cnn/save/Basic_CNN-best-model/model.pt new file mode 100644 index 00000000..41eaee5b Binary files /dev/null and b/labml_nn/cnn/save/Basic_CNN-best-model/model.pt differ diff --git a/labml_nn/cnn/utils/dataloader.py b/labml_nn/cnn/utils/dataloader.py new file mode 100644 index 00000000..a097d6b9 --- /dev/null +++ b/labml_nn/cnn/utils/dataloader.py @@ -0,0 +1,83 @@ +#!/bin/python + +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import Dataset, random_split +import matplotlib.pyplot as plt +import numpy as np + +def LoadCifar10DatasetTrain(save, transform=None): + trainset = torchvision.datasets.CIFAR10(root=save, train=True, + download=True, transform=transform) + return trainset + +def LoadCifar10DatasetTest(save, transform): + return torchvision.datasets.CIFAR10(root=save, train=False, + download=False, transform=transform) + +def GetCustTransform(): + transform_train = transforms.Compose([ + transforms.RandomRotation(20), + transforms.RandomCrop(32, (2, 2), pad_if_needed=False, padding_mode='constant'), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + return transform_train + +def Dataloader_train_valid(save, batch_size): + + # See utils/dataloader.py for data augmentations + transform_train_valid = GetCustTransform() + + # Get Cifar 10 Datasets + trainset = LoadCifar10DatasetTrain(save, transform_train_valid) + train_val_abs = int(len(trainset) * 0.8) + train_subset, val_subset = random_split(trainset, [train_val_abs, len(trainset) - train_val_abs]) + + # Get Cifar 10 Dataloaders + trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, + shuffle=True, num_workers=4) + + valloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, + shuffle=True, num_workers=4) + return trainloader, valloader + +def Dataloader_train(save, batch_size): + + # See utils/dataloader.py for data augmentations + transform_train = GetCustTransform() + + # Get Cifar 10 Datasets + trainset = LoadCifar10DatasetTrain(save, transform_train) + # Get Cifar 10 Dataloaders + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, + shuffle=True, num_workers=4) + + return trainloader + +def Dataloader_test(save, batch_size): + + # transformation test set + transform_test = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + # initialize test dataset and dataloader + testset = LoadCifar10DatasetTest(save, transform_test) + testloader = torch.utils.data.DataLoader(testset, batch_size=64, + shuffle=False, num_workers=4) + + return testloader + +def imshow(im): + image = im.cpu().clone().detach().numpy() + image = image.transpose(1, 2, 0) + image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5)) # unnormalize + plt.imshow(image) + plt.show() + +def imretrun(im): + image = im.cpu().clone().detach().numpy() + image = image.transpose(1, 2, 0) + image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5)) # unnormalize + return image \ No newline at end of file diff --git a/labml_nn/cnn/utils/train.py b/labml_nn/cnn/utils/train.py new file mode 100644 index 00000000..dbef1fa3 --- /dev/null +++ b/labml_nn/cnn/utils/train.py @@ -0,0 +1,210 @@ +#!/bin/python + +import torch.nn as nn +import matplotlib.pyplot as plt +import os +from models.cnn import GetCNN +from ray import tune +from utils.dataloader import * # Get the transforms + + +class Trainer(): + def __init__(self, name="default", device=None): + self.device = device + + self.epoch = 0 + self.start_epoch = 0 + self.name = name + + # Train function + def Train(self, net, trainloader, testloader, cost, opt, epochs = 25): + + self.net = net + self.trainloader = trainloader + self.testloader = testloader + + # Optimizer and Cost function + self.opt = opt + self.cost = cost + + # Bookkeeping + train_loss = torch.zeros(epochs) + self.epoch = 0 + train_steps = 0 + accuracy = torch.zeros(epochs) + + # Training loop + for epoch in range(self.start_epoch, self.start_epoch+epochs): + self.epoch = epoch+1 + self.net.train() # Enable Dropout + + # Iterating over train data + for data in self.trainloader: + if self.device: + images, labels = data[0].to(self.device), data[1].to(self.device) + else: + images, labels = data[0], data[1] + + self.opt.zero_grad() + + # Forward + backward + optimize + outputs = self.net(images) + epoch_loss = self.cost(outputs, labels) + epoch_loss.backward() + self.opt.step() + train_steps+=1 + + train_loss[epoch] += epoch_loss.item() + loss_train = train_loss[epoch] / train_steps + + accuracy[epoch] = self.Test() #correct / total + + print("Epoch %d LR %.6f Train Loss: %.3f Test Accuracy: %.3f" % ( + self.epoch, self.opt.param_groups[0]['lr'], loss_train, accuracy[epoch])) + + # Save best model + if accuracy[epoch] >= torch.max(accuracy): + self.save_best_model({ + 'epoch': self.epoch, + 'state_dict': self.net.state_dict(), + 'optimizer': self.opt.state_dict(), + }) + + self.plot_accuracy(accuracy) + + # Test over testloader loop + def Test(self, net = None, save=None): + # Initialize dataloader + if save == None: + testloader = self.testloader + else: + testloader = Dataloader_test(save, batch_size=128) + + # Initialize net + if net == None: + net = self.net + + # Disable Dropout + net.eval() + + # Bookkeeping + correct = 0.0 + total = 0.0 + + # Infer the model + with torch.no_grad(): + for data in testloader: + if self.device: + images, labels = data[0].to(self.device), data[1].to(self.device) + else: + images, labels = data[0], data[1] + + outputs = net(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + # compute the final accuracy + accuracy = correct / total + return accuracy + + # Train function modified for ray schedulers + def Train_ray(self, config, checkpoint_dir=None, data_dir=None): + epochs = 25 + + self.net = GetCNN(config["l1"], config["l2"]) + self.net.to(self.device) + + trainloader, valloader = Dataloader_train_valid(data_dir, batch_size=config["batch_size"]) + + # Optimizer and Cost function + self.opt = torch.optim.Adam(self.net.parameters(), lr=config["lr"], betas=(0.9, 0.95), weight_decay=config["decay"]) + self.cost = nn.CrossEntropyLoss() + + # restoring checkpoint + if checkpoint_dir: + checkpoint = os.path.join(checkpoint_dir, "checkpoint") + # checkpoint = checkpoint_dir + model_state, optimizer_state = torch.load(checkpoint) + self.net.load_state_dict(model_state) + self.opt.load_state_dict(optimizer_state) + + self.net.train() + + # Record loss/accuracies + train_loss = torch.zeros(epochs) + self.epoch = 0 + train_steps = 0 + for epoch in range(self.start_epoch, self.start_epoch+epochs): + self.epoch = epoch+1 + + self.net.train() # Enable Dropout + for data in trainloader: + # Get the inputs; data is a list of [inputs, labels] + if self.device: + images, labels = data[0].to(self.device), data[1].to(self.device) + else: + images, labels = data[0], data[1] + + self.opt.zero_grad() + # Forward + backward + optimize + outputs = self.net(images) + epoch_loss = self.cost(outputs, labels) + epoch_loss.backward() + self.opt.step() + train_steps+=1 + + train_loss[epoch] += epoch_loss.item() + + # Validation loss + val_loss = 0.0 + val_steps = 0 + total = 0 + correct = 0 + self.net.eval() + for data in valloader: + with torch.no_grad(): + # Get the inputs; data is a list of [inputs, labels] + if self.device: + images, labels = data[0].to(self.device), data[1].to(self.device) + else: + images, labels = data[0], data[1] + + # Forward + backward + optimize + outputs = self.net(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + loss = self.cost(outputs, labels) + val_loss += loss.cpu().numpy() + val_steps += 1 + + # Save checkpoints + with tune.checkpoint_dir(step=epoch) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + torch.save( + (self.net.state_dict(), self.opt.state_dict()), path) + + tune.report(loss=(val_loss / val_steps), accuracy=correct / total) + + return train_loss + + def plot_accuracy(self, accuracy, criterea = "accuracy"): + plt.plot(accuracy.numpy()) + plt.ylabel("Accuracy" if criterea == "accuracy" else "Loss") + plt.xlabel("Epochs") + plt.show() + + + def save_checkpoint(self, state): + directory = os.path.dirname("./save/%s-checkpoints/"%(self.name)) + if not os.path.exists(directory): + os.mkdir(directory) + torch.save(state, "%s/model_epoch_%s.pt" %(directory, self.epoch)) + + def save_best_model(self, state): + directory = os.path.dirname("./save/%s-best-model/"%(self.name)) + if not os.path.exists(directory): + os.mkdir(directory) + torch.save(state, "%s/model.pt" %(directory)) \ No newline at end of file