import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
from functools import partial
from skimage.filters import sobel, sobel_h, roberts
from models.cnn import CNN
from utils.dataloader import *
from utils.train import Trainer

Check if a GPU is available

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

CIFAR-10 dataset location

save = './data/Cifar10'

Training transformations

transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

Load the training dataset and dataloader

trainset = LoadCifar10DatasetTrain(save, transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)

Test transformations

transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

Load the test dataset and dataloader

testset = LoadCifar10DatasetTest(save, transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)
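LoadCifar10DatasetTrain and LoadCifar10DatasetTest come from utils/dataloader.py, which is not shown here. A minimal sketch of what they might look like, assuming they simply wrap torchvision's built-in CIFAR-10 dataset (an assumption, not the actual helpers):

import torchvision

# Hypothetical stand-ins for the helpers in utils/dataloader.py, assuming
# they wrap torchvision.datasets.CIFAR10 and download the data on demand.
def LoadCifar10DatasetTrain(save, transform):
    return torchvision.datasets.CIFAR10(root=save, train=True,
                                        download=True, transform=transform)

def LoadCifar10DatasetTest(save, transform):
    return torchvision.datasets.CIFAR10(root=save, train=False,
                                        download=True, transform=transform)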

Create the CNN model

def GetCNN():
    cnn = CNN(in_features=(32, 32, 3),
              out_features=10,
              conv_filters=[32, 32, 64, 64],
              conv_kernel_size=[3, 3, 3, 3],
              conv_strides=[1, 1, 1, 1],
              conv_pad=[0, 0, 0, 0],
              max_pool_kernels=[None, (2, 2), None, (2, 2)],
              max_pool_strides=[None, 2, None, 2],
              use_dropout=False,
              use_batch_norm=True,  # set False to disable batch norm
              actv_func=["relu", "relu", "relu", "relu"],
              device=device)
    return cnn

model = GetCNN()
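For orientation, the configuration above corresponds roughly to the plain nn.Sequential stack below, assuming the CNN class arranges each block as Conv2d -> BatchNorm2d -> ReLU with the listed pooling and ends in a single linear classifier (a sketch of the configuration, not the actual models/cnn.py implementation). With unpadded 3x3 convolutions and two 2x2 poolings, the 32x32 input shrinks to 5x5 over 64 channels, i.e. 64 * 5 * 5 = 1600 flattened features:

approx_model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU(),   # 32 x 30 x 30
    nn.Conv2d(32, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU(),  # 32 x 28 x 28
    nn.MaxPool2d(kernel_size=(2, 2), stride=2),                       # 32 x 14 x 14
    nn.Conv2d(32, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(),  # 64 x 12 x 12
    nn.Conv2d(64, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(),  # 64 x 10 x 10
    nn.MaxPool2d(kernel_size=(2, 2), stride=2),                       # 64 x 5 x 5
    nn.Flatten(),
    nn.Linear(64 * 5 * 5, 10),                                        # 10 class scores
)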

Send the model to the GPU (torchsummary runs a forward pass, so the model must already be on the target device)

model.to(device)

Display the model summary

summary(model, (3, 32, 32))

Specify the optimizer

opt = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.95))

Specify the loss function

cost = nn.CrossEntropyLoss()

Train the model

trainer = Trainer(device=device, name="Basic_CNN")
epochs = 5
trainer.Train(model, trainloader, testloader, cost=cost, opt=opt, epochs=epochs)
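Trainer is defined in utils/train.py, which is not reproduced here. A minimal sketch of the kind of loop Trainer.Train presumably runs, assuming it tracks test accuracy after each epoch and checkpoints the best weights under ./save/{name}-best-model/model.pt (that path matches the load step below; everything else is an assumption):

import os

# Hypothetical sketch of a Trainer.Train-style loop; not the actual utils/train.py.
def train_sketch(model, trainloader, testloader, cost, opt, epochs, name="Basic_CNN"):
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            loss = cost(model(images), labels)
            loss.backward()
            opt.step()

        # Evaluate on the test set and checkpoint the best model so far
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = correct / total
        if acc > best_acc:
            best_acc = acc
            os.makedirs(f"./save/{name}-best-model", exist_ok=True)
            torch.save({'state_dict': model.state_dict()},
                       f"./save/{name}-best-model/model.pt")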

Load the best saved model for inference

model_loaded = GetCNN()

Specify the location of the saved model

PATH = "./save/Basic_CNN-best-model/model.pt"
checkpoint = torch.load(PATH, map_location=device)

Load the saved weights into the model

model_loaded.load_state_dict(checkpoint['state_dict'])
model_loaded.to(device)
model_loaded.eval()  # eval mode so batch-norm layers use their running statistics

Initialization for the hooks and for storing the activations of the ReLU layers

activation = {}
hooks = []

The hook function saves the activation of a particular layer

def hook_fn(model, input, output, name):
    activation[name] = output.cpu().detach().numpy()

Register the hooks

count = 0
conv_count = 0
for name, layer in model_loaded.named_modules():
    if isinstance(layer, nn.ReLU):
        count += 1
        hook = layer.register_forward_hook(partial(hook_fn, name=f"{layer._get_name()}-{count}"))  # alternative key: f"{type(layer).__name__}-{name}"
        hooks.append(hook)
    if isinstance(layer, nn.Conv2d):
        conv_count += 1
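PyTorch invokes a forward hook as hook(module, input, output); functools.partial binds the extra name argument so the same hook_fn can serve every layer under a distinct key. A tiny standalone illustration of the trick (toy example, separate from the pipeline above):

demo_acts = {}

def demo_hook_fn(module, inputs, output, name):
    demo_acts[name] = output.detach()

relu = nn.ReLU()
h = relu.register_forward_hook(partial(demo_hook_fn, name="ReLU-demo"))
relu(torch.randn(1, 4))               # the forward pass triggers the hook
h.remove()
print(demo_acts["ReLU-demo"].shape)   # torch.Size([1, 4])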

Display the image used for inference

data, _ = trainset[15]
imshow(data)
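imshow is pulled in by the star import from utils.dataloader. A plausible sketch of such a helper, assuming it undoes the (0.5, 0.5, 0.5) normalization and reorders the tensor to channels-last for matplotlib (the real implementation is not shown):

# Hypothetical stand-in for utils.dataloader's imshow.
def imshow(img):
    img = img / 2 + 0.5                               # undo Normalize: [-1, 1] -> [0, 1]
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))  # CHW -> HWC
    plt.show()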

Run inference so the hooks capture the ReLU activations

output = model_loaded(data[None].to(device))

Remove the hooks

for hook in hooks:
    hook.remove()

Function to display the output of a particular ReLU layer

def output_one_layer(layer_num):
    assert 1 <= layer_num <= len(activation), "Wrong layer number"

    layer_name = f"ReLU-{layer_num}"
    act = activation[layer_name]
    if act.shape[1] == 32:
        rows, columns = 4, 8
    elif act.shape[1] == 64:
        rows, columns = 8, 8

    fig = plt.figure(figsize=(rows, columns))
    for idx in range(1, columns * rows + 1):
        fig.add_subplot(rows, columns, idx)
        plt.imshow(sobel(act[0][idx - 1]), cmap=plt.cm.gray)

        # Try different filters:
        # plt.imshow(act[0][idx - 1], cmap='viridis', vmin=0, vmax=act.max())
        # plt.imshow(act[0][idx - 1], cmap='hot')
        # plt.imshow(roberts(act[0][idx - 1]), cmap=plt.cm.gray)
        # plt.imshow(sobel_h(act[0][idx - 1]), cmap=plt.cm.gray)

        plt.axis('off')

    plt.tight_layout()
    plt.show()

Function to display the outputs of all ReLU layers that follow convolution layers

def output_all_layers():
    # Only the first conv_count activations (the ReLUs that follow conv layers)
    for (name, output), _ in zip(activation.items(), range(conv_count)):
        if output.shape[1] == 32:
            _, axs = plt.subplots(8, 4, figsize=(8, 4))
        elif output.shape[1] == 64:
            _, axs = plt.subplots(8, 8, figsize=(8, 8))

        for ax, out in zip(np.ravel(axs), output[0]):
            ax.imshow(sobel(out), cmap=plt.cm.gray)
            ax.axis('off')

        plt.suptitle(name)
        plt.tight_layout()
        plt.show()

Choose either one to display

output_one_layer(layer_num=3)  # choose layer number
output_all_layers()