import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        hidden_layers,
        actv_func,
        pre_module_list=None,
        use_dropout=False,
        use_batch_norm=False,
        use_softmax=True,
        device="cpu",
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.num_hidden_layers = len(hidden_layers)
        self.hidden_layers = hidden_layers
        self.use_dropout = use_dropout
        self.use_batch_norm = use_batch_norm
        self.actv_func = actv_func
        self.use_softmax = use_softmax

        self.device = device

        # Add on to another model by continuing from its module list
        if pre_module_list:
            self.module_list = pre_module_list
        else:
            self.module_list = nn.ModuleList()

        self.build_()

        # Send to GPU (or whichever device was requested)
        self.to(self.device)

    def build_(self):
        # Fully connected layers with their activation functions
        # Start with the input dimensions
        dim = self.in_features
        for i in range(self.num_hidden_layers):
            # Fully connect the previous layer to the current hidden layer
            self.module_list.append(nn.Linear(dim, self.hidden_layers[i]))
            # Update the current dimension
            dim = self.hidden_layers[i]

            if self.use_batch_norm:
                self.module_list.append(nn.BatchNorm1d(dim, affine=True))

            # Add the activation function
            self.module_list.append(self.GetActivation(name=self.actv_func[i]))

            if self.use_dropout:
                self.module_list.append(nn.Dropout(p=0.10))

        # Fully connect to the output dimensions
        if dim != self.out_features:
            self.module_list.append(nn.Linear(dim, self.out_features))

        # Optionally normalize the outputs into probabilities
        if self.use_softmax:
            self.module_list.append(nn.Softmax(dim=1))

    def forward(self, x):
        # Flatten the 2d image into 1d and convert to float for the FC layers
        x = torch.flatten(x.float(), start_dim=1)

        # Apply each layer in the module list
        for layer in self.module_list:
            x = layer(x)

        return x

    def GetActivation(self, name="relu"):
        # Map an activation name (case-insensitive) to its module
        name = name.lower()
        if name == "relu":
            return nn.ReLU()
        elif name == "leakyrelu":
            return nn.LeakyReLU()
        elif name == "sigmoid":
            return nn.Sigmoid()
        elif name == "tanh":
            return nn.Tanh()
        elif name == "identity":
            return nn.Identity()
        else:
            # Fall back to ReLU for unrecognized names
            return nn.ReLU()
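

# Minimal usage sketch, not part of the original module: the MNIST-like
# 28x28 input size, layer widths, activation names, and batch size below
# are illustrative assumptions.
if __name__ == "__main__":
    model = MLP(
        in_features=28 * 28,
        out_features=10,
        hidden_layers=[128, 64],
        actv_func=["relu", "relu"],
        use_dropout=True,
        use_batch_norm=True,
        use_softmax=True,
        device="cpu",
    )
    batch = torch.randn(32, 1, 28, 28)  # (batch, channels, height, width)
    probs = model(batch)                # forward() flattens to (32, 784)
    print(probs.shape)                  # torch.Size([32, 10])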