import os
from datetime import datetime
from glob import glob

import torch
import torch.optim as optim
from torch.utils.data import Subset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold
from torchsummary import summary

from models.cnn import GetCNN


def cross_val_train(cost, trainset, epochs, splits, device=None):
    """Train with k-fold cross-validation, logging each fold to TensorBoard,
    saving the best model per fold, and returning the per-epoch validation history."""
    patience = 4
    history = []
    kf = KFold(n_splits=splits, shuffle=True)
    batch_size = 64
    now = datetime.now()
    date_time = now.strftime("%d-%m-%Y_%H:%M:%S")
    directory = os.path.dirname('./save/tensorboard-%s/' % (date_time))

    if not os.path.exists(directory):
        os.mkdir(directory)
    for fold, (train_index, test_index) in enumerate(kf.split(trainset.data, trainset.targets)):  # requires the complete training set
        comment = f'{directory}/fold-{fold}'
        writer = SummaryWriter(log_dir=comment)

        train_sampler = SubsetRandomSampler(train_index)
        valid_sampler = SubsetRandomSampler(test_index)
        traindata = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                                                num_workers=2)
        valdata = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=valid_sampler,
                                              num_workers=2)

        net = GetCNN()
        net.to(device)
        if fold == 0:  # Print the model details only once, on the first fold
            summary(net, (3, 32, 32))

        # Specify optimizer
        optimizer = optim.Adam(net.parameters(), lr=0.0005, betas=(0.9, 0.95))
        losses = torch.zeros(epochs)
        accuracies = torch.zeros(epochs)
        min_loss = None
        count = 0
        for epoch in range(epochs):
            valid_loss = 0
            running_loss = 0.0
            epoch_loss = 0.0
            train_loss = torch.zeros(epochs)
            train_steps = 0.0

            # Training steps
            net.train()  # Enable Dropout
            for i, data in enumerate(traindata, 0):
                # Get the inputs; data is a list of [inputs, labels]
                if device:
                    images, labels = data[0].to(device), data[1].to(device)
                else:
                    images, labels = data

                # Forward + backward + optimize
                outputs = net(images)
                loss = cost(outputs, labels)
                loss.backward()
                optimizer.step()

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Accumulate the training loss
                running_loss += loss.item()
                epoch_loss += loss.item()
                train_loss[epoch] += loss.item()
                train_steps += 1

            loss_train = train_loss[epoch] / train_steps

            # Validation
            loss_accuracy = Test(net, cost, valdata, device)

            losses[epoch] = loss_accuracy['val_loss']
            accuracies[epoch] = loss_accuracy['val_acc']
            print("Fold %d, Epoch %d, Train Loss: %.4f, Validation Loss: %.4f, Validation Accuracy: %.4f"
                  % (fold + 1, epoch + 1, loss_train, losses[epoch], accuracies[epoch]))

            # TensorBoard logging
            info = {
                "Loss/train": loss_train,
                "Loss/valid": losses[epoch],
                "Accuracy/valid": accuracies[epoch]
            }

            for tag, item in info.items():
                writer.add_scalar(tag, item, global_step=epoch)

            if min_loss is None:
                min_loss = losses[epoch]
            # Early stopping, referred from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
            if losses[epoch] > min_loss:
                print("Epoch loss: %.4f, Min loss: %.4f" % (losses[epoch], min_loss))
                count += 1
                print(f'Early stopping counter: {count} out of {patience}')
                if count >= patience:
                    print('############### EarlyStopping ##################')
                    break

            # Save the best model so far
            elif losses[epoch] <= min_loss:
                count = 0
                save_best_model({
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'accuracy': accuracies[epoch]
                }, fold=fold, date_time=date_time)
                min_loss = losses[epoch]

            history.append({'val_loss': losses[epoch], 'val_acc': accuracies[epoch]})
    return history
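
# A brief usage sketch (an illustration, not part of the original script):
# cross_val_train expects a dataset that exposes `.data` and `.targets` for
# kf.split, e.g. a torchvision CIFAR-10 training set, plus a loss callable.
# The names `nn` and `cifar10_trainset` below are assumed, not defined here:
#
#   history = cross_val_train(nn.CrossEntropyLoss(), cifar10_trainset,
#                             epochs=20, splits=5,
#                             device=torch.device('cuda'))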


def save_best_model(state, fold, date_time):
    directory = os.path.dirname("./save/CV_models-%s/" % (date_time))
    if not os.path.exists(directory):
        os.mkdir(directory)
    torch.save(state, "%s/fold-%d-model.pt" % (directory, fold))


def retreive_best_trial():
    PATH = "./save/"
    best_model = GetCNN()

    # Locate the most recently modified CV_models folder
    content = os.listdir(PATH)
    latest_time = 0
    for item in content:
        if 'CV_models' in item:
            foldername = os.path.join(PATH, item)
            tm = os.path.getmtime(foldername)
            if tm > latest_time:
                latest_folder = foldername
                latest_time = tm

    file_type = '/*.pt'
    files = glob(latest_folder + file_type)

    # Load the fold checkpoint with the highest validation accuracy
    accuracy = 0
    for model_file in files:
        checkpoint = torch.load(model_file)
        if checkpoint['accuracy'] > accuracy:
            best_model.load_state_dict(checkpoint['state_dict'])
            best_val_accuracy = checkpoint['accuracy']
            accuracy = checkpoint['accuracy']

    return best_model, best_val_accuracy


def val_step(net, cost, images, labels):
    # Forward pass
    output = net(images)
    # Loss on this batch
    loss = cost(output, labels)
    # Batch accuracy
    _, preds = torch.max(output, dim=1)
    acc = torch.tensor(torch.sum(preds == labels).item() / len(preds))
    acc_output = {'val_loss': loss.detach(), 'val_acc': acc}
    return acc_output
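
# Note: val_step is not called elsewhere in this module; a sketch of how it
# could score one validation batch at a time (the loader/device names are
# illustrative assumptions):
#
#   net.eval()
#   with torch.no_grad():
#       batch_stats = [val_step(net, cost, imgs.to(device), lbls.to(device))
#                      for imgs, lbls in valdata]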


# Test over a testloader/valloader loop
def Test(net, cost, testloader, device):
    # Disable Dropout
    net.eval()

    # Bookkeeping
    correct = 0.0
    total = 0.0
    loss = 0.0
    train_steps = 0.0

    # Run inference on the evaluation set
    with torch.no_grad():
        for data in testloader:
            if device:
                images, labels = data[0].to(device), data[1].to(device)
            else:
                images, labels = data[0], data[1]

            outputs = net(images)

            # Loss on this batch
            loss += cost(outputs, labels)
            train_steps += 1

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        loss = loss / train_steps

    accuracy = correct / total
    loss_accuracy = {'val_loss': loss, 'val_acc': accuracy}
    return loss_accuracy
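

if __name__ == "__main__":
    # End-to-end sketch under stated assumptions: a torchvision CIFAR-10 dataset
    # (matching the (3, 32, 32) input summarised above) and cross-entropy loss.
    # The imports and dataset choice here are illustrative, not confirmed by the
    # original module.
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.ToTensor()
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                             shuffle=False, num_workers=2)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cost = nn.CrossEntropyLoss()

    # cross_val_train and save_best_model create sub-folders under ./save/,
    # so make sure the parent directory exists first
    os.makedirs('./save', exist_ok=True)

    cross_val_train(cost, trainset, epochs=20, splits=5, device=device)

    # Reload the best fold checkpoint and evaluate it on the held-out test set
    best_model, best_val_accuracy = retreive_best_trial()
    best_model.to(device)
    results = Test(best_model, cost, testloader, device)
    print("Best fold validation accuracy: %.4f, Test accuracy: %.4f"
          % (float(best_val_accuracy), results['val_acc']))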