import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torchsummary import summary

# custom imports
import numpy as np
import time
import os


# ResBlock
class ResBlock(nn.Module):
    def __init__(self, num_features, use_batch_norm=False):
        super(ResBlock, self).__init__()
        self.num_features = num_features
        self.conv_layer1 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)
        self.relu_layer = nn.ReLU()
        self.conv_layer2 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)

        self.use_batch_norm = use_batch_norm
        if self.use_batch_norm:
            self.batch_norm_layer1 = nn.BatchNorm2d(self.num_features)
            self.batch_norm_layer2 = nn.BatchNorm2d(self.num_features)

        # He initialization for the conv weights
        # (Xavier is a possible alternative: nn.init.xavier_uniform_(m.weight))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
    def forward(self, x):
        residual = x
        x = self.conv_layer1(x)
        if self.use_batch_norm:
            x = self.batch_norm_layer1(x)

        x = self.relu_layer(x)
        x = self.conv_layer2(x)
        if self.use_batch_norm:
            x = self.batch_norm_layer2(x)

        x += residual  # identity skip connection
        x = self.relu_layer(x)
        return x
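
# A quick shape sanity check (illustrative, not part of the original script):
# a ResBlock maps (N, C, H, W) to the same shape, since both 3x3 convs use
# stride 1 / padding 1 and the skip connection requires matching shapes.
#   block = ResBlock(num_features=16, use_batch_norm=True)
#   out = block(torch.randn(2, 16, 32, 32))   # -> torch.Size([2, 16, 32, 32])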
# ResNet
class ResNet(nn.Module):
    def __init__(self, in_features, num_class, feature_channel_list, batch_norm=False, num_stacks=1, zero_init_residual=True):
        super(ResNet, self).__init__()
        self.in_features = in_features
        # a 2-element shape (H, W) means a single-channel input
        self.num_in_channel = in_features[2] if len(in_features) == 3 else 1
        self.num_class = num_class
        self.feature_channel_list = feature_channel_list
        self.num_residual_blocks = len(self.feature_channel_list)
        self.num_stacks = num_stacks
        self.batch_norm = batch_norm
        self.shape_list = []
        self.shape_list.append(in_features)
        self.module_list = nn.ModuleList()
        self.zero_init_residual = zero_init_residual
        self.build_()
    def build_(self):
        # track the filter shape through the network
        # (padding=3 here must match the first conv layer below)
        cur_shape = self.GetCurShape()
        cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=7, padding=3, stride=2,
                                          out_filters=self.feature_channel_list[0])
        self.shape_list.append(cur_shape)

        if len(self.in_features) == 2:
            in_channels = 1
        else:
            in_channels = self.in_features[2]

        # first conv layer: 7x7, stride=2, padding=3
        self.module_list.append(nn.Conv2d(in_channels=in_channels,
                                          out_channels=self.feature_channel_list[0],
                                          kernel_size=7,
                                          stride=2,
                                          padding=3))

        # batch norm
        if self.batch_norm:
            self.module_list.append(nn.BatchNorm2d(self.feature_channel_list[0]))

        # ReLU
        self.module_list.append(nn.ReLU())
        for i in range(self.num_residual_blocks - 1):
            in_size = self.feature_channel_list[i]
            out_size = self.feature_channel_list[i + 1]

            # construct a fresh ResBlock per stack, so stacked blocks
            # do not share weights
            for num in range(self.num_stacks):
                self.module_list.append(ResBlock(in_size, use_batch_norm=self.batch_norm))

            # downsampling conv between stages
            self.module_list.append(nn.Conv2d(in_channels=in_size,
                                              out_channels=out_size,
                                              kernel_size=3,
                                              padding=1,
                                              stride=2))

            # track the filter shape
            cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=3, padding=1,
                                              stride=2, out_filters=out_size)
            self.shape_list.append(cur_shape)

            if self.batch_norm:
                self.module_list.append(nn.BatchNorm2d(out_size))

            self.module_list.append(nn.ReLU())
        print("shape list", self.shape_list)

        # TODO: include in the main loop
        # last residual block
        for num in range(self.num_stacks):
            self.module_list.append(ResBlock(out_size, use_batch_norm=self.batch_norm))

        # last pooling layer (an AvgPool2d with the same arguments is an alternative:
        # self.module_list.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)))
        self.module_list.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))

        # track the filter shape
        cur_shape = self.CalcConvOutShape(cur_shape, kernel_size=2, padding=0, stride=2, out_filters=out_size)
        self.shape_list.append(cur_shape)

        s = self.GetCurShape()
        in_features = s[0] * s[1] * s[2]  # flattened size for an optional MLP head
        # initialization: He for conv weights
        # (Xavier is a possible alternative: nn.init.xavier_uniform_(m.weight))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

        # zero-initialize the last BN in each residual block, so every block
        # starts as an identity mapping (as in torchvision's zero_init_residual)
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResBlock) and m.use_batch_norm:
                    nn.init.constant_(m.batch_norm_layer2.weight, 0)
    def GetCurShape(self):
        return self.shape_list[-1]

    def CalcConvFormula(self, W, K, P, S):
        # standard conv output-size formula: floor((W - K + 2P) / S + 1)
        return int(np.floor(((W - K + 2 * P) / S) + 1))
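
    # Worked example (illustrative): for the first conv here (K=7, P=3, S=2)
    # on a 32-pixel side, W=32 gives floor((32 - 7 + 6) / 2 + 1) = 16, so a
    # 32x32 input is downsampled to 16x16.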
    # Calculate the output shape after applying a convolution; see
    # https://stackoverflow.com/questions/53580088/calculate-the-output-size-in-convolution-layer
    def CalcConvOutShape(self, in_shape, kernel_size, padding, stride, out_filters):
        # supports both square (int) and rectangular (tuple) kernels
        if isinstance(kernel_size, int):
            out_shape = [self.CalcConvFormula(in_shape[i], kernel_size, padding, stride) for i in range(2)]
        else:
            out_shape = [self.CalcConvFormula(in_shape[i], kernel_size[i], padding, stride) for i in range(2)]

        return (out_shape[0], out_shape[1], out_filters)  # batch size omitted; not needed here
    def AddMLP(self, MLP):
        if MLP:
            self.module_list.append(MLP)

    def MLP(self, in_features, num_classes, use_batch_norm=False, use_dropout=False, use_softmax=False):
        # nn.ReLU(nn.Linear(...)) was a bug: nn.ReLU ignores its argument.
        # A Sequential applies the linear layer, then the activation.
        return nn.Sequential(nn.Linear(in_features, num_classes), nn.ReLU())
    def forward(self, x):
        for module in self.module_list:
            if isinstance(module, (nn.Linear, nn.Sequential)):
                x = x.view(x.size(0), -1)  # flatten before a fully connected head
            x = module(x)
        x = x.view(x.size(0), -1)  # flatten (no-op if already flat)
        return x
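
# A minimal usage sketch (illustrative; the input size, channel list, and class
# count are assumptions, not values from the original script): build a small
# ResNet for 32x32 RGB inputs, attach the MLP head sized from the tracked
# shape, and run a dummy batch.
if __name__ == "__main__":
    model = ResNet(in_features=(32, 32, 3), num_class=10,
                   feature_channel_list=[16, 32], batch_norm=True, num_stacks=1)
    h, w, c = model.GetCurShape()  # final (H, W, C) recorded in shape_list
    model.AddMLP(model.MLP(h * w * c, num_classes=10))
    out = model(torch.randn(4, 3, 32, 32))
    print(out.shape)  # expected: torch.Size([4, 10])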