3import torch
4import torchvision
5import torchvision.transforms as transforms
6
7import torch.nn as nn
8import torch.nn.functional as F
9
10import matplotlib.pyplot as plt
11import numpy as np
12
13from sklearn.model_selection import KFold
14from torch.utils.data.sampler import SubsetRandomSampler

# Plot the loss curves of multiple runs together in one figure.

def PlotLosses(losses, titles, save=None):
    """Plot several loss curves in one figure, one stacked subplot per run.

    Args:
        losses: Sequence of loss sequences; ``losses[i]`` is drawn in row
            ``i`` against its index, which is labelled "Epoch".
        titles: Sequence of subplot titles; ``titles[i]`` labels
            ``losses[i]``. Must contain at least ``len(losses)`` entries.
        save: Optional output path. When truthy the figure is written there
            and then closed; otherwise it is displayed with ``plt.show()``.

    Raises:
        ValueError: If fewer titles than loss curves are supplied.
    """
    if len(titles) < len(losses):
        raise ValueError(
            "Need a title for every loss curve: got %d titles for %d curves"
            % (len(titles), len(losses)))

    fig = plt.figure()
    fig.set_size_inches(14, 22)

    # Plot results as a single column with one row per run. The explicit
    # (nrows, ncols, index) form is used instead of a packed integer such as
    # 311, which matplotlib only accepts while every part is a single digit
    # (i.e. it breaks once there are 10 or more runs).
    nrows = len(losses)
    for index, (loss, title) in enumerate(zip(losses, titles), start=1):
        ax = fig.add_subplot(nrows, 1, index)
        ax.plot(range(len(loss)), loss)
        ax.set_xlabel("Epoch")
        ax.set_title(title)
        ax.set_ylabel("Loss")

    # Save the figure, or show it interactively.
    if save:
        fig.savefig(save)
        # Release the figure: when saving repeatedly (e.g. in a loop),
        # figures left open accumulate in pyplot's figure manager.
        plt.close(fig)
    else:
        plt.show()
def ClassSpecificTestCifar10(net, testdata, device=None, classes=None):
    """Print, and return, the per-class accuracy of ``net`` on ``testdata``.

    Args:
        net: Callable model. The argmax of ``net(images)`` over dimension 1
            is taken as the predicted label for each sample.
        testdata: Iterable of ``(images, labels)`` batches (e.g. a
            ``DataLoader``). Batches may have any size, including a short
            final batch or a batch of one.
        device: Optional device. When given, images and labels are moved
            there with ``.to(device)`` before the forward pass.
        classes: Optional sequence of class names indexed by label. Defaults
            to the ten CIFAR-10 class names.

    Returns:
        A list with one entry per class: the accuracy in percent, or
        ``None`` for a class that never appears in ``testdata``.

    Note:
        This function does not call ``net.eval()``; switch the model to
        evaluation mode beforehand if that matters for it.
    """
    if classes is None:
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')
    num_classes = len(classes)
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    with torch.no_grad():
        for data in testdata:
            if device:
                images, labels = data[0].to(device), data[1].to(device)
            else:
                images, labels = data

            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            # Walk the whole batch instead of a fixed count, so any batch
            # size is counted correctly. No squeeze(): it would turn a
            # batch of one into a 0-d tensor that cannot be iterated.
            hits = (predicted == labels).reshape(-1).tolist()
            for label, hit in zip(labels.reshape(-1).tolist(), hits):
                class_correct[label] += hit
                class_total[label] += 1

    # Print out per-class accuracy. Classes without any test samples are
    # reported as N/A rather than raising ZeroDivisionError.
    accuracies = []
    for name, correct, total in zip(classes, class_correct, class_total):
        if total:
            accuracy = 100 * correct / total
            print('Accuracy of %5s : %2d %%' % (name, accuracy))
        else:
            accuracy = None
            print('Accuracy of %5s : N/A (no samples)' % name)
        accuracies.append(accuracy)
    return accuracies
def GetActivation(name="relu"):
    """Return a new activation module selected by name.

    Matching is case-insensitive, so "relu"/"ReLU" and "Sigmoid"/"sigmoid"
    are equivalent. Recognized names: relu, leakyrelu, sigmoid, tanh,
    identity.

    Args:
        name: Activation name. Defaults to "relu".

    Returns:
        A freshly constructed ``nn.Module``. Unrecognized names (or a
        non-string ``name``) fall back to ``nn.ReLU()`` for backward
        compatibility, but now emit a ``UserWarning`` so typos are visible.
    """
    import warnings

    factories = {
        "relu": nn.ReLU,
        "leakyrelu": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
        "identity": nn.Identity,
    }
    key = name.lower() if isinstance(name, str) else None
    factory = factories.get(key)
    if factory is None:
        warnings.warn(
            "Unknown activation %r; falling back to ReLU" % (name,),
            stacklevel=2)
        factory = nn.ReLU
    return factory()