import os

import matplotlib.pyplot as plt
from pylab import *
from sklearn.model_selection import KFold
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
class Trainer():
    """Training/evaluation harness for a PyTorch model.

    Wraps a network, optimizer and loss function; optionally drives a
    ReduceLROnPlateau learning-rate scheduler and saves the
    best-performing model under ./save/<name>-best-model/.
    """

    def __init__(self, net, opt, cost, name="default", lr=0.0005,
                 use_lr_schedule=False, device=None):
        """
        Args:
            net: model to train (torch.nn.Module).
            opt: optimizer already bound to net's parameters.
            cost: loss function, called as cost(outputs, labels).
            name: tag used in the checkpoint / best-model save paths.
            lr: nominal learning rate (stored for reference only; the
                optimizer's own lr is what is actually applied).
            use_lr_schedule: if True, attach a ReduceLROnPlateau scheduler.
            device: optional torch device; when set, each batch is moved
                to it before the forward pass.
        """
        self.net = net
        self.opt = opt
        self.cost = cost
        self.device = device
        self.epoch = 0           # 1-based number of the epoch last run
        self.start_epoch = 0     # first epoch index (for checkpoint resume)
        self.name = name

        self.lr = lr
        self.use_lr_schedule = use_lr_schedule
        if self.use_lr_schedule:
            # NOTE(review): mode='max' assumes the monitored metric is an
            # accuracy (Train passes test accuracy when a testloader is
            # given). When Train runs without a testloader it monitors the
            # training loss, for which 'min' would be correct — confirm
            # intended usage.
            self.scheduler = ReduceLROnPlateau(
                self.opt, 'max', factor=0.1, patience=5,
                threshold=0.00001, verbose=True)
            # Alternative fixed-step schedule (disabled):
            # self.scheduler = StepLR(self.opt, step_size=15, gamma=0.1)

    def Train(self, trainloader, epochs, testloader=None):
        """Train for `epochs` epochs; return a per-epoch metric tensor.

        If `testloader` is given, the returned tensor holds the test
        accuracy after each epoch; otherwise it holds the mean training
        loss per epoch.  The best model so far (highest accuracy, or
        lowest loss when no testloader is used) is saved after each epoch.
        """
        metric = torch.zeros(epochs)
        self.epoch = 0

        for epoch in range(self.start_epoch, self.start_epoch + epochs):
            # Index relative to start_epoch so resuming from a checkpoint
            # (start_epoch > 0) cannot index past the end of `metric`.
            i = epoch - self.start_epoch
            self.epoch = epoch + 1

            self.net.train()  # enable Dropout / BatchNorm training behavior
            for data in trainloader:
                # data is a list of [inputs, labels]
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                self.opt.zero_grad()

                # forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                metric[i] += batch_loss.item()

            if testloader:
                # Replace the accumulated loss with the test accuracy.
                metric[i] = self.Test(testloader)
            else:
                metric[i] /= len(trainloader)

            print("Epoch %d Learning rate %.6f %s: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'],
                "Accuracy" if testloader else "Loss", metric[i]))

            # Learning-rate scheduler (ReduceLROnPlateau consumes the metric).
            if self.use_lr_schedule:
                self.scheduler.step(metric[i])
                # For StepLR use: self.scheduler.step()

            # Save the best model so far: highest accuracy when a
            # testloader is used, otherwise lowest training loss.
            # (The original compared against max() in both cases, which
            # saved the *worst* model when tracking loss.)
            seen = metric[:i + 1]
            is_best = metric[i] >= seen.max() if testloader else metric[i] <= seen.min()
            if is_best:
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        return metric

    def Test(self, testloader, ret="accuracy"):
        """Evaluate on `testloader`; return the fraction of correct predictions."""
        self.net.eval()  # disable Dropout / use BatchNorm running statistics

        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for data in testloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                outputs = self.net(images)
                # Predicted class = argmax over the class dimension.
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def save_best_model(self, state):
        """Persist `state` as the best model for this run."""
        directory = os.path.dirname("./save/%s-best-model/" % (self.name))
        # makedirs (not mkdir): also creates the parent ./save directory.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" % (directory))

    def save_checkpoint(self, state):
        """Persist `state` as a per-epoch checkpoint."""
        directory = os.path.dirname("./save/%s-checkpoints/" % (self.name))
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))
        # torch.save(state, "./save/checkpoints/model_epoch_%s.pt" % (self.epoch))