import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from ray import tune

from models.cnn import GetCNN
from utils.dataloader import *  # Get the transforms
class Trainer():
    """Train and evaluate a classification network, optionally as a Ray Tune trainable.

    Attributes set in ``__init__``:
        device: when truthy, the first two elements of every loader item are
            moved there with ``.to(device)``; when falsy they are used as-is.
        epoch: 1-based number of the epoch currently / most recently run.
        start_epoch: offset of the epoch counter; training loops iterate
            ``range(start_epoch, start_epoch + epochs)``.
        name: used to build the ``./save/<name>-...`` output directories.
    """

    def __init__(self, name="default", device=None):
        self.device = device

        self.epoch = 0
        self.start_epoch = 0
        self.name = name

    def _batch_to_device(self, data):
        """Split a loader item into ``(images, labels)``, moved to ``self.device`` if set."""
        images, labels = data[0], data[1]
        if self.device:
            images, labels = images.to(self.device), labels.to(self.device)
        return images, labels

    # Train function
    def Train(self, net, trainloader, testloader, cost, opt, epochs = 25):
        """Train ``net`` for ``epochs`` epochs, testing after every epoch.

        After each epoch this prints the mean training loss and the test
        accuracy. It saves model/optimizer state through ``save_best_model``
        whenever the accuracy is at least as high as every earlier epoch's in
        this call. At the end it plots the per-epoch accuracy curve.

        Args:
            net: model to train; stored as ``self.net``.
            trainloader: iterable of ``(inputs, labels)`` training batches.
            testloader: loader used by ``Test``; stored as ``self.testloader``.
            cost: loss function, called as ``cost(outputs, labels)``.
            opt: optimizer; stored as ``self.opt``.
            epochs: number of epochs to run.
        """
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

        # Optimizer and cost function
        self.opt = opt
        self.cost = cost

        # Bookkeeping. Slots are indexed by the epoch's offset from
        # start_epoch, so a resumed run does not index past the end.
        train_loss = torch.zeros(epochs)
        accuracy = torch.zeros(epochs)
        self.epoch = 0

        # Training loop
        for idx, epoch in enumerate(range(self.start_epoch, self.start_epoch + epochs)):
            self.epoch = epoch + 1
            self.net.train()  # Enable Dropout

            # Reset every epoch so the printed loss is this epoch's mean,
            # not its sum divided by the steps of all epochs so far.
            train_steps = 0
            for data in self.trainloader:
                images, labels = self._batch_to_device(data)

                self.opt.zero_grad()
                # Forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                train_steps += 1
                train_loss[idx] += batch_loss.item()

            loss_train = train_loss[idx] / train_steps if train_steps else float("nan")

            accuracy[idx] = self.Test()  # correct / total

            print("Epoch %d LR %.6f Train Loss: %.3f Test Accuracy: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'], loss_train, accuracy[idx]))

            # Save best model (a tie with the best so far also saves)
            if accuracy[idx] >= accuracy[: idx + 1].max():
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        self.plot_accuracy(accuracy)

    # Test over testloader loop
    def Test(self, net = None, save=None):
        """Return the top-1 accuracy of a network over a test loader.

        Args:
            net: model to evaluate; defaults to ``self.net``.
            save: when not ``None``, the loader is built with
                ``Dataloader_test(save, batch_size=128)`` instead of using
                ``self.testloader`` (presumably a data location — confirm
                against ``utils.dataloader``).

        Returns:
            Fraction (float) of samples whose argmax output equals the label,
            or NaN when the loader yields no samples.

        Leaves ``net`` in eval mode.
        """
        # Initialize dataloader
        if save is None:
            testloader = self.testloader
        else:
            testloader = Dataloader_test(save, batch_size=128)
        # Initialize net
        if net is None:
            net = self.net

        net.eval()  # Disable Dropout

        correct = 0.0
        total = 0.0
        # Infer the model
        with torch.no_grad():
            for data in testloader:
                images, labels = self._batch_to_device(data)
                outputs = net(images)
                predicted = torch.argmax(outputs, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty loader has no defined accuracy; avoid ZeroDivisionError.
        return correct / total if total else float("nan")

    def _validate(self, loader):
        """Return ``(mean batch loss, accuracy)`` of ``self.net`` over ``loader``.

        Each value is NaN when the loader yields no batches or no samples.
        Leaves ``self.net`` in eval mode.
        """
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        self.net.eval()
        with torch.no_grad():
            for data in loader:
                images, labels = self._batch_to_device(data)
                outputs = self.net(images)
                predicted = torch.argmax(outputs, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_loss += self.cost(outputs, labels).item()
                val_steps += 1
        mean_loss = val_loss / val_steps if val_steps else float("nan")
        accuracy = correct / total if total else float("nan")
        return mean_loss, accuracy

    # Train function modified for ray schedulers
    def Train_ray(self, config, checkpoint_dir=None, data_dir=None):
        """Train a fresh ``GetCNN`` model as a Ray Tune function trainable.

        After every epoch this writes a ``(model_state, optimizer_state)``
        checkpoint through ``tune.checkpoint_dir``. It then reports the mean
        validation loss and the validation accuracy with ``tune.report``.

        Args:
            config: hyper-parameters. Reads ``"l1"`` and ``"l2"`` (passed to
                ``GetCNN``), ``"batch_size"``, ``"lr"``, ``"decay"`` (Adam
                weight decay) and, optionally, ``"epochs"`` (default 25).
            checkpoint_dir: if given, a directory whose ``checkpoint`` file
                holds ``(model_state, optimizer_state)`` to resume from.
            data_dir: passed to ``Dataloader_train_valid``.

        Returns:
            Tensor with the summed training loss of each epoch.
        """
        epochs = config.get("epochs", 25)

        self.net = GetCNN(config["l1"], config["l2"])
        if self.device:
            self.net.to(self.device)

        trainloader, valloader = Dataloader_train_valid(data_dir, batch_size=config["batch_size"])

        # Optimizer and cost function
        self.opt = torch.optim.Adam(self.net.parameters(), lr=config["lr"],
                                    betas=(0.9, 0.95), weight_decay=config["decay"])
        self.cost = nn.CrossEntropyLoss()

        # Restore model/optimizer state when Tune resumes this trial
        if checkpoint_dir:
            checkpoint = os.path.join(checkpoint_dir, "checkpoint")
            model_state, optimizer_state = torch.load(checkpoint)
            self.net.load_state_dict(model_state)
            self.opt.load_state_dict(optimizer_state)

        # Record loss, indexed by the epoch's offset from start_epoch
        train_loss = torch.zeros(epochs)
        self.epoch = 0
        for idx, epoch in enumerate(range(self.start_epoch, self.start_epoch + epochs)):
            self.epoch = epoch + 1

            self.net.train()  # Enable Dropout
            for data in trainloader:
                images, labels = self._batch_to_device(data)

                self.opt.zero_grad()
                # Forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                train_loss[idx] += batch_loss.item()

            # Validation loss / accuracy
            val_loss, val_accuracy = self._validate(valloader)

            # Save checkpoints (renamed so the `checkpoint_dir` argument is not shadowed)
            with tune.checkpoint_dir(step=epoch) as ckpt_dir:
                path = os.path.join(ckpt_dir, "checkpoint")
                torch.save((self.net.state_dict(), self.opt.state_dict()), path)

            tune.report(loss=val_loss, accuracy=val_accuracy)

        return train_loss

    def plot_accuracy(self, accuracy, criterea = "accuracy"):
        """Plot a per-epoch tensor with matplotlib and show the figure.

        The y-axis is labelled "Accuracy" when ``criterea == "accuracy"``,
        otherwise "Loss".
        """
        plt.plot(accuracy.numpy())
        plt.ylabel("Accuracy" if criterea == "accuracy" else "Loss")
        plt.xlabel("Epochs")
        plt.show()

    def save_checkpoint(self, state):
        """Save ``state`` to ``./save/<name>-checkpoints/model_epoch_<epoch>.pt``."""
        directory = os.path.join(".", "save", "%s-checkpoints" % self.name)
        # makedirs also creates ./save on first use (os.mkdir raised if it was
        # missing) and does not fail if the directory already exists.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, os.path.join(directory, "model_epoch_%s.pt" % self.epoch))

    def save_best_model(self, state):
        """Save ``state`` to ``./save/<name>-best-model/model.pt``, overwriting any previous file."""
        directory = os.path.join(".", "save", "%s-best-model" % self.name)
        # makedirs also creates ./save on first use and tolerates existing dirs.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, os.path.join(directory, "model.pt"))