import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torchsummary import summary

# custom imports

import numpy as np
import time
import os

# ResBlock

class ResBlock(nn.Module):
    def __init__(self, num_features, use_batch_norm=False):
        super(ResBlock, self).__init__()
        self.num_features = num_features
        self.conv_layer1 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)
        self.relu_layer = nn.ReLU()
        self.conv_layer2 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)

        self.use_batch_norm = use_batch_norm
        if self.use_batch_norm:
            self.batch_norm_layer1 = nn.BatchNorm2d(self.num_features)
            self.batch_norm_layer2 = nn.BatchNorm2d(self.num_features)

        # Kaiming (He) initialization for the conv weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # alternative: nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        residual = x
        x = self.conv_layer1(x)
        if self.use_batch_norm:
            x = self.batch_norm_layer1(x)

        x = self.relu_layer(x)
        x = self.conv_layer2(x)
        if self.use_batch_norm:
            x = self.batch_norm_layer2(x)

        # skip connection: add the block input back before the final activation
        x += residual
        x = self.relu_layer(x)
        return x
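
# A minimal shape-check sketch for the block above (hedged; the 64-channel,
# 32x32 input is an illustrative assumption, not a value from the original
# training setup). Because both convs use kernel_size=3, stride=1, padding=1
# and keep num_features channels, the output shape matches the input shape,
# which is exactly what the `x += residual` skip connection requires:
#
#   block = ResBlock(num_features=64, use_batch_norm=True)
#   out = block(torch.randn(2, 64, 32, 32))
#   assert out.shape == (2, 64, 32, 32)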

# ResNet

class ResNet(nn.Module):
    def __init__(self, in_features, num_class, feature_channel_list, batch_norm=False, num_stacks=1, zero_init_residual=True):
        super(ResNet, self).__init__()
        self.in_features = in_features
        self.num_in_channel = in_features[2]
        self.num_class = num_class
        self.feature_channel_list = feature_channel_list
        self.num_residual_blocks = len(self.feature_channel_list)
        self.num_stacks = num_stacks
        self.batch_norm = batch_norm
        self.shape_list = []
        self.shape_list.append(in_features)
        self.module_list = nn.ModuleList()
        self.zero_init_residual = zero_init_residual
        self.build_()

    def build_(self):

        # track filter shape

        cur_shape = self.GetCurShape()
        cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=7, padding=3, stride=2, out_filters=self.feature_channel_list[0])
        self.shape_list.append(cur_shape)

        if len(self.in_features) == 2:
            in_channels = 1
        else:
            in_channels = self.in_features[2]

        # first conv layer: 7x7, stride=2, padding=3

        self.module_list.append(nn.Conv2d(in_channels=in_channels,
                                          out_channels=self.feature_channel_list[0],
                                          kernel_size=7,
                                          stride=2,
                                          padding=3))

        # batch norm

        if self.batch_norm:
            self.module_list.append(nn.BatchNorm2d(self.feature_channel_list[0]))

        # ReLU

        self.module_list.append(nn.ReLU())

        for i in range(self.num_residual_blocks - 1):
            in_size = self.feature_channel_list[i]
            out_size = self.feature_channel_list[i + 1]

            # stack residual blocks; build a fresh ResBlock per stack so the
            # stacked blocks do not share weights
            for num in range(self.num_stacks):
                self.module_list.append(ResBlock(in_size, use_batch_norm=True))

            # intermediate (downsampling) conv and ReLU

            self.module_list.append(nn.Conv2d(in_channels=in_size,
                                              out_channels=out_size,
                                              kernel_size=3,
                                              padding=1,
                                              stride=2))

            # track filter shape

            cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=3, padding=1,
                                              stride=2, out_filters=out_size)
            self.shape_list.append(cur_shape)

            # batch norm

            if self.batch_norm:
                self.module_list.append(nn.BatchNorm2d(out_size))

            self.module_list.append(nn.ReLU())

        # print("shape list", self.shape_list)

        # TODO: fold the last residual block into the main loop above

        for num in range(self.num_stacks):
            self.module_list.append(ResBlock(out_size, use_batch_norm=True))

        # last pooling layer; AvgPool alternative:
        # self.module_list.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))

        self.module_list.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))

        # track filter shape

        cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=2, padding=0, stride=2, out_filters=out_size)
        self.shape_list.append(cur_shape)

        s = self.GetCurShape()
        in_features = s[0] * s[1] * s[2]  # flattened feature size (useful for an MLP head, see AddMLP)

        # initialization

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # alternative: nn.init.xavier_uniform_(m.weight)

        # zero-init the residual-branch BatchNorm weights so each ResBlock
        # starts out close to an identity mapping
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResBlock):
                    nn.init.constant_(m.batch_norm_layer1.weight, 0)
                    nn.init.constant_(m.batch_norm_layer2.weight, 0)

    def GetCurShape(self):
        return self.shape_list[-1]

    def CalcConvFormula(self, W, K, P, S):
        return int(np.floor(((W - K + 2 * P) / S) + 1))

    # Calculate the output shape after applying a convolution
    # https://stackoverflow.com/questions/53580088/calculate-the-output-size-in-convolution-layer

    def CalcConvOutShape(self, in_shape, kernel_size, padding, stride, out_filters):

        # multiple options for different kernel shapes

        if isinstance(kernel_size, int):
            out_shape = [self.CalcConvFormula(in_shape[i], kernel_size, padding, stride) for i in range(2)]
        else:
            out_shape = [self.CalcConvFormula(in_shape[i], kernel_size[i], padding, stride) for i in range(2)]

        return (out_shape[0], out_shape[1], out_filters)  # batch size omitted; not needed here
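
    # Worked example of the formula above (a hedged check; the 32x32 input size
    # is an illustrative assumption): with the stem conv's hyperparameters from
    # build_, W=32, K=7, P=3, S=2 gives floor((32 - 7 + 2*3) / 2 + 1) = 16, so
    # the 7x7/stride-2 conv halves the spatial size from 32x32 to 16x16.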
    def AddMLP(self, MLP):
        if MLP:
            self.module_list.append(MLP)

    def MLP(self, in_features, num_classes, use_batch_norm=False, use_dropout=False, use_softmax=False):
        return nn.Sequential(nn.Linear(in_features, num_classes), nn.ReLU())

    def forward(self, x):
        for module in self.module_list:
            x = module(x)
        x = x.view(x.size(0), -1)  # flatten  # TODO: check if it works
        return x
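
# A minimal usage sketch (hedged): the input resolution, channel list and class
# count below are illustrative assumptions, not values from the original script.
# It builds the network for 32x32 RGB inputs and checks that the flattened
# feature size returned by forward() matches the tracked shape list.
if __name__ == "__main__":
    model = ResNet(in_features=(32, 32, 3),
                   num_class=10,
                   feature_channel_list=[16, 32, 64],
                   batch_norm=True,
                   num_stacks=1)
    dummy = torch.randn(4, 3, 32, 32)      # [batch, channels, H, W]
    features = model(dummy)
    h, w, c = model.GetCurShape()          # shape tracked after the last pooling layer
    print(features.shape)                  # expected: torch.Size([4, h * w * c])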