import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torchsummary import summary

# custom imports
import numpy as np
import time
import os

# ResBlock
class ResBlock(nn.Module):
    def __init__(self, num_features, use_batch_norm=False):
        super(ResBlock, self).__init__()
        self.num_features = num_features
        self.conv_layer1 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)
        self.relu_layer = nn.ReLU()
        self.conv_layer2 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)

        self.use_batch_norm = use_batch_norm
        if self.use_batch_norm:
            self.batch_norm_layer1 = nn.BatchNorm2d(self.num_features)
            self.batch_norm_layer2 = nn.BatchNorm2d(self.num_features)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # nn.init.xavier_uniform_(m.weight)  # alternative initializer
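                # He (Kaiming) init is the usual choice for ReLU networks,
                # since it preserves activation variance through the rectifier;
                # Xavier targets symmetric activations such as tanh.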

    def forward(self, x):
        residual = x
        x = self.conv_layer1(x)
        if self.use_batch_norm:
            x = self.batch_norm_layer1(x)

        x = self.relu_layer(x)
        x = self.conv_layer2(x)
        if self.use_batch_norm:
            x = self.batch_norm_layer2(x)

        x += residual
        x = self.relu_layer(x)
        return x
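
# Hypothetical sanity check (not in the original listing): the stride-1,
# padding-1 3x3 convs mean a ResBlock preserves the input shape exactly.
_block = ResBlock(num_features=8, use_batch_norm=True)
assert _block(torch.randn(1, 8, 16, 16)).shape == (1, 8, 16, 16)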

# ResNet
class ResNet(nn.Module):
    def __init__(self, in_features, num_class, feature_channel_list,
                 batch_norm=False, num_stacks=1, zero_init_residual=True):
        super(ResNet, self).__init__()
        self.in_features = in_features
        self.num_in_channel = in_features[2]
        self.num_class = num_class
        self.feature_channel_list = feature_channel_list
        self.num_residual_blocks = len(self.feature_channel_list)
        self.num_stacks = num_stacks
        self.batch_norm = batch_norm
        self.shape_list = []
        self.shape_list.append(in_features)
        self.module_list = nn.ModuleList()
        self.zero_init_residual = zero_init_residual
        self.build_()

    def build_(self):
        # track filter shape
        cur_shape = self.GetCurShape()
        # padding=3 here matches the actual first conv below (the original
        # tracked padding=1, which made the shape bookkeeping wrong)
        cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=7, padding=3, stride=2,
                                          out_filters=self.feature_channel_list[0])
        self.shape_list.append(cur_shape)

        if len(self.in_features) == 2:
            in_channels = 1
        else:
            in_channels = self.in_features[2]

        # First conv layer: 7x7, stride=2, padding=3
        self.module_list.append(nn.Conv2d(in_channels=in_channels,
                                          out_channels=self.feature_channel_list[0],
                                          kernel_size=7,
                                          stride=2,
                                          padding=3))
        # batch norm
        if self.batch_norm:
            self.module_list.append(nn.BatchNorm2d(self.feature_channel_list[0]))
        # ReLU
        self.module_list.append(nn.ReLU())

        for i in range(self.num_residual_blocks - 1):
            in_size = self.feature_channel_list[i]
            out_size = self.feature_channel_list[i + 1]

            # build a fresh ResBlock per stack; appending the same instance
            # repeatedly would make the stacked blocks share parameters
            for num in range(self.num_stacks):
                self.module_list.append(ResBlock(in_size, use_batch_norm=True))

            # downsampling conv between stages
            self.module_list.append(nn.Conv2d(in_channels=in_size,
                                              out_channels=out_size,
                                              kernel_size=3,
                                              padding=1,
                                              stride=2))
            # track filter shape
            cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=3, padding=1,
                                              stride=2, out_filters=out_size)
            self.shape_list.append(cur_shape)

            if self.batch_norm:
                self.module_list.append(nn.BatchNorm2d(out_size))

            self.module_list.append(nn.ReLU())
        print("shape list", self.shape_list)
        # TODO: include in the main loop
        # Last residual block
        for num in range(self.num_stacks):
            self.module_list.append(ResBlock(out_size, use_batch_norm=True))

        # Last pooling layer
        # self.module_list.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
        self.module_list.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        # track filter shape
        cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=2, padding=0, stride=2, out_filters=out_size)
        self.shape_list.append(cur_shape)

        s = self.GetCurShape()
        in_features = s[0] * s[1] * s[2]

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # nn.init.xavier_uniform_(m.weight)  # alternative initializer
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResBlock):
                    nn.init.constant_(m.batch_norm_layer1.weight, 0)
                    nn.init.constant_(m.batch_norm_layer2.weight, 0)
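        # Zero-initializing the residual-branch BatchNorm gammas makes each
        # block start out close to an identity mapping, which tends to
        # stabilize early training (torchvision's zero_init_residual does the
        # same, though it zeroes only the last BN of each block).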

    def GetCurShape(self):
        return self.shape_list[-1]

    # Calculate the output size after applying a convolution, see
    # https://stackoverflow.com/questions/53580088/calculate-the-output-size-in-convolution-layer
    def CalcConvFormula(self, W, K, P, S):
        return int(np.floor(((W - K + 2 * P) / S) + 1))
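        # Example: W=32, K=3, P=1, S=2 -> floor((32 - 3 + 2) / 2 + 1) = 16,
        # i.e. each stride-2 conv in build_() halves the spatial size.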

    # Calculate the output shape (H, W, C) after a conv; handles both a single
    # int kernel size and a per-axis kernel shape
    def CalcConvOutShape(self, in_shape, kernel_size, padding, stride, out_filters):
        if isinstance(kernel_size, int):
            out_shape = [self.CalcConvFormula(in_shape[i], kernel_size, padding, stride) for i in range(2)]
        else:
            out_shape = [self.CalcConvFormula(in_shape[i], kernel_size[i], padding, stride) for i in range(2)]

        return (out_shape[0], out_shape[1], out_filters)  # batch size omitted; not needed here

    def AddMLP(self, MLP):
        if MLP:
            self.module_list.append(MLP)

    def MLP(self, in_features, num_classes, use_batch_norm=False, use_dropout=False, use_softmax=False):
        # note: nn.ReLU(nn.Linear(...)) would pass the Linear module as ReLU's
        # `inplace` argument; wrap the layers in nn.Sequential instead
        return nn.Sequential(nn.Linear(in_features, num_classes), nn.ReLU())
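    # The use_batch_norm / use_dropout / use_softmax flags are accepted but
    # not yet wired up; the head is currently a plain Linear followed by ReLU.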

    def forward(self, x):
        for module in self.module_list:
            x = module(x)
        # flatten to (batch, features); TODO: an MLP appended via AddMLP runs
        # inside the loop above, i.e. before this flatten
        x = x.view(x.size(0), -1)
        return x
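
# A minimal usage sketch, assuming CIFAR-10-style 32x32 RGB inputs and a
# [16, 32, 64] channel schedule (both hypothetical; neither value appears in
# the original listing). The classifier head is applied after forward(), since
# forward() flattens only at the end.
if __name__ == "__main__":
    net = ResNet(in_features=(32, 32, 3), num_class=10,
                 feature_channel_list=[16, 32, 64],
                 batch_norm=True, num_stacks=2)
    h, w, c = net.GetCurShape()             # final (H, W, C) after the pool
    head = net.MLP(h * w * c, 10)           # Linear + ReLU head
    feats = net(torch.randn(4, 3, 32, 32))  # (4, h*w*c) flat features
    logits = head(feats)                    # (4, 10)
    print(feats.shape, logits.shape)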