import torch
import torch.nn as nn


# Lower-cased activation name -> module constructor. Looked up
# case-insensitively by MLP.GetActivation.
_ACTIVATIONS = {
    "relu": nn.ReLU,
    "leakyrelu": nn.LeakyReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "identity": nn.Identity,
}


class MLP(nn.Module):
    """Fully connected network built from a list of hidden-layer widths.

    For every entry of ``hidden_layers`` the network appends, in order:
    ``nn.Linear`` -> optional ``nn.BatchNorm1d`` -> activation -> optional
    ``nn.Dropout``. A final ``nn.Linear`` to ``out_features`` is appended only
    when the last width differs from ``out_features``.

    Args:
        in_features: Size of the flattened input fed to the first layer.
        out_features: Size of the network output.
        hidden_layers: Sequence of hidden-layer widths.
        actv_func: Either a sequence with one activation name per hidden
            layer, or a single name applied to every hidden layer. See
            ``GetActivation`` for the recognised names.
        pre_module_list: Optional existing modules to build on top of. An
            ``nn.Module`` container (e.g. ``nn.ModuleList``) is used as-is, so
            the new layers are appended to the caller's container; any other
            iterable of modules is wrapped in a new ``nn.ModuleList`` so its
            parameters are registered with this model.
        use_dropout: Append ``nn.Dropout`` after each hidden activation.
        use_batch_norm: Insert ``nn.BatchNorm1d`` before each hidden
            activation.
        use_softmax: Stored on the instance only; ``forward`` does not apply
            a softmax.
        device: Device the model is moved to after construction.
        dropout_p: Drop probability used when ``use_dropout`` is true.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_layers,
                 actv_func,
                 pre_module_list=None,
                 use_dropout=False,
                 use_batch_norm=False,
                 use_softmax=True,
                 device="cpu",
                 dropout_p=0.10):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.num_hidden_layers = len(hidden_layers)
        self.hidden_layers = hidden_layers
        self.use_dropout = use_dropout
        self.use_batch_norm = use_batch_norm
        self.actv_func = actv_func
        self.use_softmax = use_softmax
        self.dropout_p = dropout_p

        self.device = device

        # Add on to another model
        if pre_module_list:
            if isinstance(pre_module_list, nn.Module):
                self.module_list = pre_module_list
            else:
                # A plain list would be stored as an ordinary attribute and
                # its parameters would be invisible to .to()/parameters().
                self.module_list = nn.ModuleList(pre_module_list)
        else:
            self.module_list = nn.ModuleList()

        self.build_()

        # Send to gpu
        self.to(self.device)

    def build_(self):
        """Append the hidden stack and (if needed) the output layer."""
        # Activation functions for the fully connected layers; a single name
        # is repeated for every hidden layer.
        names = self.actv_func
        if isinstance(names, str):
            names = [names] * self.num_hidden_layers

        # Start with input dimensions
        dim = self.in_features
        for i, width in enumerate(self.hidden_layers):
            # Create a fully connected layer between the last layer and the
            # current hidden layer
            self.module_list.append(nn.Linear(dim, width))

            # Update the current dimension
            dim = width

            if self.use_batch_norm:
                self.module_list.append(nn.BatchNorm1d(dim, affine=True))

            # Add the activation function
            self.module_list.append(self.GetActivation(name=names[i]))

            if self.use_dropout:
                self.module_list.append(nn.Dropout(p=self.dropout_p))

        # Fully connect to output dimensions. Skipped when the last width
        # already equals out_features, in which case the last hidden block's
        # output is the network output.
        if dim != self.out_features:
            self.module_list.append(nn.Linear(dim, self.out_features))

    def forward(self, x):
        """Run ``x`` through every module in ``module_list`` in order.

        The input is cast to float and every dimension after the first is
        flattened into one before the first module is applied.
        """
        x = torch.flatten(x.float(), start_dim=1)

        # Apply each layer in the module list
        for layer in self.module_list:
            x = layer(x)

        return x

    def GetActivation(self, name="relu"):
        """Return a new activation module for ``name``.

        Recognised names, matched case-insensitively: ``relu``,
        ``leakyrelu``, ``sigmoid``, ``tanh``, ``identity``. Any other value
        falls back to ``nn.ReLU``.
        """
        key = name.lower() if isinstance(name, str) else None
        return _ACTIVATIONS.get(key, nn.ReLU)()