3import torch.nn as nn
4import matplotlib.pyplot as plt
5import os
6from models.cnn import GetCNN
7from ray import tune
8from utils.dataloader import * # Get the transforms
class Trainer():
    """Train, evaluate and checkpoint a classification network.

    One instance can drive a plain training run (``Train``), a stand-alone
    evaluation (``Test``) or a Ray Tune trial (``Train_ray``).
    """

    def __init__(self, name="default", device=None):
        """Create a trainer.

        Args:
            name: Label used to build the ``./save/<name>-...`` output
                directories.
            device: Device that batches (and, in ``Train_ray``, the network)
                are moved to. When falsy, tensors are used where they are.
        """
        self.device = device

        self.epoch = 0        # 1-based number of the epoch being processed
        self.start_epoch = 0  # first value of the epoch counter in the loops
        self.name = name

    def _split_batch(self, batch):
        """Return ``(images, labels)`` from *batch*, moved to ``self.device`` if set."""
        images, labels = batch[0], batch[1]
        if self.device:
            images, labels = images.to(self.device), labels.to(self.device)
        return images, labels

    def _train_one_epoch(self, loader):
        """Run one optimisation pass over *loader* with ``self.net``.

        Returns:
            A ``(summed_loss, steps)`` tuple: the sum of the per-batch loss
            values and the number of batches processed.
        """
        self.net.train()  # Enable Dropout
        summed_loss = 0.0
        steps = 0
        for batch in loader:
            images, labels = self._split_batch(batch)

            self.opt.zero_grad()

            # Forward + backward + optimize
            outputs = self.net(images)
            batch_loss = self.cost(outputs, labels)
            batch_loss.backward()
            self.opt.step()

            summed_loss += batch_loss.item()
            steps += 1
        return summed_loss, steps

    def Train(self, net, trainloader, testloader, cost, opt, epochs = 25):
        """Train *net*, evaluating on *testloader* after every epoch.

        Prints one progress line per epoch, calls ``save_best_model`` whenever
        the epoch's accuracy is at least as high as every accuracy recorded so
        far, and finally plots the accuracy curve.

        Args:
            net: Network to train; stored on ``self.net``.
            trainloader: Iterable of ``(images, labels)`` training batches.
            testloader: Iterable of ``(images, labels)`` evaluation batches.
            cost: Loss callable invoked as ``cost(outputs, labels)``.
            opt: Optimizer updating the parameters of *net*.
            epochs: Number of epochs to run.
        """
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

        # Optimizer and Cost function
        self.opt = opt
        self.cost = cost

        # Bookkeeping
        self.epoch = 0
        accuracy = torch.zeros(epochs)

        # Training loop. ``idx`` is relative to the start of this run so the
        # bookkeeping tensor is never indexed past ``epochs`` when
        # ``start_epoch`` is non-zero.
        for idx, epoch in enumerate(range(self.start_epoch, self.start_epoch + epochs)):
            self.epoch = epoch + 1

            summed_loss, steps = self._train_one_epoch(self.trainloader)
            # Average over this epoch's batches only (not the cumulative step
            # count), so the reported loss is comparable between epochs.
            loss_train = summed_loss / steps if steps else 0.0

            accuracy[idx] = self.Test()

            print("Epoch %d LR %.6f Train Loss: %.3f Test Accuracy: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'], loss_train, accuracy[idx]))

            # Save best model. Entries for epochs not yet run are still zero,
            # so the max only reflects accuracies seen so far.
            if accuracy[idx] >= torch.max(accuracy):
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        self.plot_accuracy(accuracy)

    def Test(self, net = None, save=None):
        """Compute classification accuracy over a test loader.

        Args:
            net: Network to evaluate; defaults to ``self.net``.
            save: When given, a loader is built with
                ``Dataloader_test(save, batch_size=128)``; otherwise
                ``self.testloader`` is used.

        Returns:
            Fraction of correctly predicted samples as a float, or ``0.0``
            when the loader yields no samples. The network is left in eval
            mode.
        """
        # Initialize dataloader
        if save is None:
            testloader = self.testloader
        else:
            testloader = Dataloader_test(save, batch_size=128)

        # Initialize net
        if net is None:
            net = self.net

        # Disable Dropout
        net.eval()

        # Bookkeeping
        correct = 0
        total = 0

        # Infer the model
        with torch.no_grad():
            for batch in testloader:
                images, labels = self._split_batch(batch)

                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Compute the final accuracy; an empty loader has no defined accuracy,
        # report 0.0 instead of raising ZeroDivisionError.
        return correct / total if total else 0.0

    def Train_ray(self, config, checkpoint_dir=None, data_dir=None):
        """Training function adapted for Ray Tune schedulers.

        Builds the network, data loaders, optimizer and loss from *config*,
        optionally restores a checkpoint, then for every epoch trains,
        validates, writes a Tune checkpoint and reports validation loss and
        accuracy through ``tune.report``.

        Args:
            config: Mapping with keys ``"l1"``, ``"l2"`` (passed to
                ``GetCNN``), ``"batch_size"``, ``"lr"`` and ``"decay"``. An
                optional ``"epochs"`` key overrides the default of 25 epochs.
            checkpoint_dir: Directory holding a ``checkpoint`` file that
                contains a ``(model_state, optimizer_state)`` tuple; restored
                when given.
            data_dir: Forwarded to ``Dataloader_train_valid``.

        Returns:
            Tensor of length ``epochs`` holding, per epoch, the sum of the
            training batch losses.
        """
        epochs = config.get("epochs", 25)

        self.net = GetCNN(config["l1"], config["l2"])
        if self.device:
            self.net.to(self.device)

        trainloader, valloader = Dataloader_train_valid(data_dir, batch_size=config["batch_size"])

        # Optimizer and Cost function
        self.opt = torch.optim.Adam(self.net.parameters(), lr=config["lr"], betas=(0.9, 0.95), weight_decay=config["decay"])
        self.cost = nn.CrossEntropyLoss()

        # Restoring checkpoint
        if checkpoint_dir:
            checkpoint = os.path.join(checkpoint_dir, "checkpoint")
            model_state, optimizer_state = torch.load(checkpoint)
            self.net.load_state_dict(model_state)
            self.opt.load_state_dict(optimizer_state)

        # Record loss/accuracies
        train_loss = torch.zeros(epochs)
        self.epoch = 0
        # ``idx`` is relative to this run so ``train_loss`` is never indexed
        # out of range when ``start_epoch`` is non-zero.
        for idx, epoch in enumerate(range(self.start_epoch, self.start_epoch + epochs)):
            self.epoch = epoch + 1

            summed_loss, _ = self._train_one_epoch(trainloader)
            train_loss[idx] = summed_loss

            # Validation loss
            val_loss = 0.0
            val_steps = 0
            total = 0
            correct = 0
            self.net.eval()
            with torch.no_grad():
                for batch in valloader:
                    images, labels = self._split_batch(batch)

                    outputs = self.net(images)
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    # .item() keeps the running loss a plain Python float.
                    val_loss += self.cost(outputs, labels).item()
                    val_steps += 1

            # Save checkpoints (separate name: do not shadow the parameter)
            with tune.checkpoint_dir(step=epoch) as ckpt_dir:
                path = os.path.join(ckpt_dir, "checkpoint")
                torch.save(
                    (self.net.state_dict(), self.opt.state_dict()), path)

            tune.report(loss=(val_loss / val_steps), accuracy=correct / total)

        return train_loss

    def plot_accuracy(self, accuracy, criterea = "accuracy"):
        """Plot a per-epoch curve with matplotlib and show it.

        Args:
            accuracy: Tensor of per-epoch values to plot.
            criterea: ``"accuracy"`` labels the y-axis "Accuracy"; any other
                value labels it "Loss".
        """
        plt.plot(accuracy.numpy())
        plt.ylabel("Accuracy" if criterea == "accuracy" else "Loss")
        plt.xlabel("Epochs")
        plt.show()

    def save_checkpoint(self, state):
        """Save *state* to ``./save/<name>-checkpoints/model_epoch_<epoch>.pt``.

        Args:
            state: Object handed to ``torch.save``.
        """
        directory = os.path.dirname("./save/%s-checkpoints/"%(self.name))
        # makedirs also creates the parent ``./save`` directory when missing.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" %(directory, self.epoch))

    def save_best_model(self, state):
        """Save *state* to ``./save/<name>-best-model/model.pt``.

        Args:
            state: Object handed to ``torch.save``.
        """
        directory = os.path.dirname("./save/%s-best-model/"%(self.name))
        # makedirs also creates the parent ``./save`` directory when missing.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" %(directory))