3import torch
4
5from torch.utils.data import Subset
6
7from sklearn.model_selection import KFold
8from torch.utils.data.sampler import SubsetRandomSampler
9from models.cnn import GetCNN
10from torchsummary import summary
11import torch.optim as optim
12import os
13
14from torch.utils.tensorboard import SummaryWriter
15
16from datetime import datetime
17from glob import glob
def cross_val_train(cost, trainset, epochs, splits, device=None):
    """Run K-fold cross-validation training with early stopping and TensorBoard logs.

    For each fold a fresh CNN is trained; the best (lowest validation loss)
    checkpoint per fold is saved via ``save_best_model``.

    Args:
        cost: loss function, e.g. ``nn.CrossEntropyLoss()``.
        trainset: full training dataset; must expose ``.data`` and ``.targets``
            (as torchvision CIFAR-style datasets do — confirm for other datasets).
        epochs: maximum number of epochs per fold.
        splits: number of K-fold splits.
        device: optional torch device; when falsy, batches stay on CPU.

    Returns:
        list of dicts, one per fold: ``{'val_loss', 'val_acc'}`` of the last
        epoch that ran for that fold.
    """
    patience = 4
    history = []
    kf = KFold(n_splits=splits, shuffle=True)
    batch_size = 64
    now = datetime.now()
    # NOTE(review): colons in the timestamp are invalid in Windows paths —
    # confirm this only targets POSIX filesystems.
    date_time = now.strftime("%d-%m-%Y_%H:%M:%S")
    directory = os.path.dirname('./save/tensorboard-%s/' % (date_time))

    # Fix: makedirs (not mkdir) also creates the missing parent './save';
    # exist_ok avoids the check-then-create race.
    os.makedirs(directory, exist_ok=True)

    for fold, (train_index, test_index) in enumerate(kf.split(trainset.data, trainset.targets)):
        comment = f'{directory}/fold-{fold}'
        writer = SummaryWriter(log_dir=comment)

        train_sampler = SubsetRandomSampler(train_index)
        valid_sampler = SubsetRandomSampler(test_index)
        traindata = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                sampler=train_sampler, num_workers=2)
        valdata = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              sampler=valid_sampler, num_workers=2)

        net = GetCNN()
        net.to(device)
        if fold == 0:  # print model details once, on the first fold only
            summary(net, (3, 32, 32))

        optimizer = optim.Adam(net.parameters(), lr=0.0005, betas=(0.9, 0.95))
        losses = torch.zeros(epochs)
        accuracies = torch.zeros(epochs)
        min_loss = None
        count = 0  # early-stopping counter
        for epoch in range(epochs):
            # Scalar accumulator replaces the original redundant
            # torch.zeros(epochs) tensor re-allocated every epoch.
            epoch_train_loss = 0.0
            train_steps = 0

            net.train()  # enable Dropout
            for data in traindata:
                # data is a list of [inputs, labels]
                if device:
                    images, labels = data[0].to(device), data[1].to(device)
                else:
                    images, labels = data

                # Zero grads before backward (conventional order; equivalent
                # to the original zero-after-step since the net starts fresh).
                optimizer.zero_grad()
                outputs = net(images)
                loss = cost(outputs, labels)
                loss.backward()
                optimizer.step()

                epoch_train_loss += loss.item()
                train_steps += 1

            loss_train = epoch_train_loss / train_steps

            # Validation pass over this fold's held-out split.
            loss_accuracy = Test(net, cost, valdata, device)
            losses[epoch] = loss_accuracy['val_loss']
            accuracies[epoch] = loss_accuracy['val_acc']
            print("Fold %d, Epoch %d, Train Loss %.4f Validation Loss: %.4f, Validation Accuracy: %.4f" % (fold+1, epoch+1, loss_train, losses[epoch], accuracies[epoch]))

            # TensorBoard scalars for this epoch.
            info = {
                "Loss/train": loss_train,
                "Loss/valid": losses[epoch],
                "Accuracy/valid": accuracies[epoch],
            }
            for tag, item in info.items():
                writer.add_scalar(tag, item, global_step=epoch)

            if min_loss is None:
                min_loss = losses[epoch]

            # Early stopping, adapted from
            # https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
            if losses[epoch] > min_loss:
                print("Epoch loss: %.4f, Min loss: %.4f" % (losses[epoch], min_loss))
                count += 1
                print(f'Early stopping counter: {count} out of {patience}')
                if count >= patience:
                    print(f'############### EarlyStopping ##################')
                    break
            else:
                # New best validation loss: reset the counter and checkpoint.
                count = 0
                save_best_model({
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'accuracy': accuracies[epoch],
                }, fold=fold, date_time=date_time)
                min_loss = losses[epoch]

        history.append({'val_loss': losses[epoch], 'val_acc': accuracies[epoch]})
        # Fix: flush and release the writer's file handles for this fold.
        writer.close()
    return history
def save_best_model(state, fold, date_time):
    """Persist a checkpoint dict for one fold under ./save/CV_models-<date_time>/.

    Args:
        state: checkpoint dict (epoch, state_dict, optimizer, accuracy).
        fold: fold index, used in the output filename.
        date_time: run timestamp, used in the directory name.
    """
    directory = os.path.dirname("./save/CV_models-%s/" % (date_time))
    # Fix: makedirs also creates the missing parent './save' (mkdir would
    # raise FileNotFoundError); exist_ok avoids the check-then-create race.
    os.makedirs(directory, exist_ok=True)
    torch.save(state, "%s/fold-%d-model.pt" % (directory, fold))
def retreive_best_trial():
    """Load the highest-accuracy checkpoint from the most recent CV run.

    Scans ./save/ for the newest 'CV_models-*' directory (by mtime), then
    picks the fold checkpoint with the highest stored validation accuracy.

    Returns:
        (best_model, best_val_accuracy): the loaded model and its accuracy.

    Raises:
        FileNotFoundError: if no CV_models directory or no .pt files exist.
    """
    PATH = "./save/"
    best_model = GetCNN()

    latest_time = 0
    latest_folder = None
    for item in os.listdir(PATH):
        if 'CV_models' in item:
            foldername = os.path.join(PATH, item)
            tm = os.path.getmtime(foldername)
            if tm > latest_time:
                # Fix: latest_time was never updated, so the "latest" folder
                # was just the last one iterated, not the newest by mtime.
                latest_time = tm
                latest_folder = foldername

    if latest_folder is None:
        raise FileNotFoundError("no CV_models-* directory found under %s" % PATH)

    files = glob(latest_folder + '/*.pt')
    if not files:
        raise FileNotFoundError("no .pt checkpoints found in %s" % latest_folder)

    best_val_accuracy = 0
    for model_file in files:
        checkpoint = torch.load(model_file)
        # Fix: the best accuracy seen so far was never updated, so every
        # checkpoint with accuracy > 0 overwrote the previous "best".
        if checkpoint['accuracy'] > best_val_accuracy:
            best_model.load_state_dict(checkpoint['state_dict'])
            best_val_accuracy = checkpoint['accuracy']

    return best_model, best_val_accuracy
def val_step(net, cost, images, labels):
    """Run a single validation batch.

    Returns a dict with the detached batch loss ('val_loss') and the
    batch accuracy as a scalar tensor ('val_acc').
    """
    # Forward pass and batch loss.
    logits = net(images)
    batch_loss = cost(logits, labels)

    # Accuracy: fraction of argmax predictions matching the labels.
    _, predicted = torch.max(logits, dim=1)
    hit_rate = torch.sum(predicted == labels).item() / len(predicted)

    return {'val_loss': batch_loss.detach(), 'val_acc': torch.tensor(hit_rate)}
def Test(net, cost, testloader, device):
    """Evaluate `net` over a test/validation loader.

    Args:
        net: model to evaluate (put into eval mode here).
        cost: loss function applied per batch.
        testloader: iterable of (images, labels) batches.
        device: optional torch device; when falsy, batches stay on CPU.

    Returns:
        dict with 'val_loss' (mean batch loss, tensor) and 'val_acc' (float).
    """
    net.eval()  # disable Dropout / switch norm layers to eval statistics

    num_correct = 0.0
    num_samples = 0.0
    total_loss = 0.0
    num_batches = 0.0

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for batch in testloader:
            if device:
                inputs, targets = batch[0].to(device), batch[1].to(device)
            else:
                inputs, targets = batch[0], batch[1]

            logits = net(inputs)
            total_loss += cost(logits, targets)
            num_batches += 1

            _, predicted = torch.max(logits.data, 1)
            num_samples += targets.size(0)
            num_correct += (predicted == targets).sum().item()
        mean_loss = total_loss / num_batches

    return {'val_loss': mean_loss, 'val_acc': num_correct / num_samples}