"""Simple PyTorch network definitions (dice_ml.utils.neuralnetworks)."""

from torch import nn, sigmoid


class FFNetwork(nn.Module):
    """Feed-forward network with one hidden ReLU layer and a single output.

    The input is flattened (all dimensions after the first) before the
    linear stack. The raw output goes through a sigmoid:

    * classifier mode: the output lies in (0, 1);
    * regression mode: the sigmoid is scaled by ``output_scale``, so the
      output lies in (0, output_scale).

    Args:
        input_size: Number of input features after flattening.
        is_classifier: If False, scale the sigmoid output by ``output_scale``.
        hidden_size: Width of the hidden layer (default 16).
        output_scale: Upper bound of the regression output (default 3.0).
    """

    def __init__(self, input_size: int, is_classifier: bool = True, *,
                 hidden_size: int = 16, output_scale: float = 3.0):
        super().__init__()
        self.is_classifier = is_classifier
        self.output_scale = output_scale
        self.flatten = nn.Flatten()
        # Attribute names are kept unchanged so previously saved state_dicts
        # still load into this module.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x):
        """Return sigmoid outputs of shape (batch, 1), scaled for regression."""
        x = self.flatten(x)
        out = sigmoid(self.linear_relu_stack(x))
        if not self.is_classifier:
            # Regression target range is (0, output_scale).
            out = self.output_scale * out
        return out
class MulticlassNetwork(nn.Module):
    """Feed-forward network with one hidden ReLU layer and a softmax output.

    Args:
        input_size: Number of input features.
        num_class: Number of output classes (width of the softmax output).
        hidden_size: Width of the hidden layer (default 16).
    """

    def __init__(self, input_size: int, num_class: int, *, hidden_size: int = 16):
        super().__init__()
        # Attribute names are kept unchanged so previously saved state_dicts
        # still load into this module.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_class),
        )
        # Normalizes over dim 1; assumes batched input of shape
        # (batch, input_size) — TODO confirm with callers.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1)."""
        logits = self.linear_relu_stack(x)
        return self.softmax(logits)