3import torch
4from torch.utils.data import DataLoader, ConcatDataset
from sklearn.model_selection import KFold from torch.utils.data.sampler import SubsetRandomSampler
8import matplotlib.pyplot as plt
9from pylab import *
10import os
11
12from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
class Trainer():
    """Train and evaluate a classifier, saving the best-scoring weights to disk.

    Attributes:
        net: The network being trained; called as ``net(inputs)``.
        opt: Optimizer that updates ``net``'s parameters.
        cost: Loss callable invoked as ``cost(outputs, labels)``.
        device: Target passed to ``Tensor.to`` for every batch, or ``None``
            to use the batches exactly as the loader yields them.
        epoch: 1-based number of the most recently started epoch.
        start_epoch: Number of epochs already completed before ``Train`` is
            called; set it before resuming so epoch numbering continues.
        name: Prefix of the directories created under ``./save``.
        lr: Learning rate given at construction. It is only stored here; the
            optimizer's own parameter groups hold the rate actually used.
        use_lr_schedule: Whether a ``ReduceLROnPlateau`` scheduler is stepped
            after every epoch.
        scheduler: The scheduler; only present when ``use_lr_schedule`` is set.
    """

    def __init__(self, net, opt, cost, name: str = "default", lr: float = 0.0005,
                 use_lr_schedule: bool = False, device=None):
        """Store the training components and optionally create the LR scheduler."""
        self.net = net
        self.opt = opt
        self.cost = cost
        self.device = device
        self.epoch = 0
        self.start_epoch = 0
        self.name = name

        self.lr = lr
        self.use_lr_schedule = use_lr_schedule
        if self.use_lr_schedule:
            # Starts in 'max' mode (accuracy monitoring); Train() switches the
            # direction when it ends up monitoring the training loss instead.
            self.scheduler = self._make_scheduler("max")

    def _make_scheduler(self, mode):
        """Build the plateau scheduler for ``mode`` ('max' or 'min')."""
        # `verbose` is deliberately not passed: recent PyTorch releases
        # deprecate it, and Train() already prints the learning rate per epoch.
        return ReduceLROnPlateau(
            self.opt, mode=mode, factor=0.1, patience=5, threshold=0.00001)

    def _to_device(self, batch):
        """Return ``(inputs, labels)`` from a batch, moved to ``self.device`` if set."""
        inputs, labels = batch[0], batch[1]
        # `is not None` rather than truthiness so that device index 0 is honoured.
        if self.device is not None:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
        return inputs, labels

    def Train(self, trainloader, epochs: int, testloader=None) -> torch.Tensor:
        """Run the training loop for ``epochs`` epochs.

        Args:
            trainloader: Iterable of batches whose first two items are the
                inputs and the labels.
            epochs: Number of epochs to run in this call.
            testloader: Optional evaluation loader. When given, the metric
                recorded for each epoch is the accuracy returned by ``Test``;
                otherwise it is the mean training loss over the batches.

        Returns:
            A 1-D tensor of length ``epochs`` with one metric per epoch, in
            the order the epochs were run.
        """
        use_accuracy = testloader is not None
        metric_name = "Accuracy" if use_accuracy else "Loss"
        history = torch.zeros(epochs)
        self.epoch = 0

        if self.use_lr_schedule:
            # Accuracy must go up, loss must go down. Rebuild the scheduler
            # when its direction does not match the metric monitored here.
            wanted_mode = "max" if use_accuracy else "min"
            if getattr(self.scheduler, "mode", wanted_mode) != wanted_mode:
                self.scheduler = self._make_scheduler(wanted_mode)

        # Start from the worst possible value so NaN never counts as "best".
        best = float("-inf") if use_accuracy else float("inf")

        # `slot` indexes this call's history; the absolute epoch number is
        # offset by start_epoch so a resumed run cannot index out of range.
        for slot in range(epochs):
            self.epoch = self.start_epoch + slot + 1

            self.net.train()  # training mode (enables dropout)
            running_loss = 0.0
            batch_count = 0
            for batch in trainloader:
                images, labels = self._to_device(batch)

                self.opt.zero_grad()
                # Forward + backward + optimize
                batch_loss = self.cost(self.net(images), labels)
                batch_loss.backward()
                self.opt.step()

                running_loss += batch_loss.item()
                batch_count += 1

            if use_accuracy:
                value = self.Test(testloader)
            else:
                # An empty loader yields NaN, which is never saved as best.
                value = running_loss / batch_count if batch_count else float("nan")
            history[slot] = value

            print("Epoch %d Learning rate %.6f %s: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'], metric_name, value))

            # Learning rate scheduler
            if self.use_lr_schedule:
                self.scheduler.step(value)

            # Save the best model seen in this call; ties keep the later epoch.
            improved = value >= best if use_accuracy else value <= best
            if improved:
                best = value
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        return history

    def Test(self, testloader, ret="accuracy") -> float:
        """Return the fraction of samples in ``testloader`` classified correctly.

        The predicted class is the index of the largest output along
        dimension 1. The network is left in evaluation mode afterwards.

        Args:
            testloader: Iterable of batches whose first two items are the
                inputs and the labels.
            ret: Unused; kept so existing callers keep working.

        Returns:
            ``correct / total`` as a float, or ``0.0`` when the loader yields
            no samples.
        """
        self.net.eval()  # evaluation mode (disables dropout)

        correct = 0
        total = 0
        with torch.no_grad():
            for batch in testloader:
                images, labels = self._to_device(batch)
                predicted = self.net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            return 0.0
        return correct / total

    def save_best_model(self, state) -> None:
        """Write ``state`` to ``./save/<name>-best-model/model.pt``."""
        directory = "./save/%s-best-model" % (self.name)
        # makedirs also creates ./save itself and tolerates an existing directory.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" % (directory))

    def save_checkpoint(self, state) -> None:
        """Write ``state`` to ``./save/<name>-checkpoints/model_epoch_<epoch>.pt``."""
        directory = "./save/%s-checkpoints" % (self.name)
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))