Source code for dice_ml.utils.sample_architecture.vae_model

import torch
from torch import nn


class CF_VAE(nn.Module):

    def __init__(self, d, encoded_size):
        super(CF_VAE, self).__init__()

        self.encoded_size = encoded_size
        self.data_size = len(d.ohe_encoded_feature_names)
        self.minx, self.maxx, self.encoded_categorical_feature_indexes, self.encoded_continuous_feature_indexes, \
            self.cont_minx, self.cont_maxx, self.cont_precisions = d.get_data_params_for_gradient_dice()

        # Recompute the continuous indexes as the complement of the flattened
        # categorical (one-hot) indexes.
        flattened_indexes = [item for sublist in self.encoded_categorical_feature_indexes for item in sublist]
        self.encoded_continuous_feature_indexes = [ix for ix in range(len(self.minx[0])) if ix not in flattened_indexes]
        self.encoded_start_cat = len(self.encoded_continuous_feature_indexes)

        # Plus 1 to the input encoding size and data size to incorporate the target class label
        self.encoder_mean = nn.Sequential(
            nn.Linear(self.data_size + 1, 20), nn.BatchNorm1d(20), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(20, 16), nn.BatchNorm1d(16), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(16, 14), nn.BatchNorm1d(14), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(14, 12), nn.BatchNorm1d(12), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(12, self.encoded_size)
        )

        self.encoder_var = nn.Sequential(
            nn.Linear(self.data_size + 1, 20), nn.BatchNorm1d(20), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(20, 16), nn.BatchNorm1d(16), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(16, 14), nn.BatchNorm1d(14), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(14, 12), nn.BatchNorm1d(12), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(12, self.encoded_size), nn.Sigmoid()
        )

        # Plus 1 to the input encoding size and data size to incorporate the target class label
        self.decoder_mean = nn.Sequential(
            nn.Linear(self.encoded_size + 1, 12), nn.BatchNorm1d(12), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(12, 14), nn.BatchNorm1d(14), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(14, 16), nn.BatchNorm1d(16), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(16, 20), nn.BatchNorm1d(20), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(20, self.data_size), nn.Sigmoid()
        )

    def encoder(self, x):
        mean = self.encoder_mean(x)
        # Despite the name, this holds the variance itself: the sigmoid output
        # in (0, 1) is shifted by 0.5 so the variance stays bounded away from zero.
        logvar = 0.5 + self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        mean = self.decoder_mean(z)
        return mean

    def sample_latent_code(self, mean, logvar):
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I);
        # `logvar` holds the variance, so sigma = sqrt(logvar).
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        # Gaussian log-likelihood of x under N(mean, logvar), reduced over `raxis`.
        return torch.sum(-0.5 * ((x - mean) * (1. / logvar) * (x - mean) + torch.log(logvar)), axis=raxis)

    def forward(self, x, c):
        # torch.as_tensor avoids the copy-construct warning when c is already a tensor.
        c = torch.as_tensor(c).float()
        c = c.view(c.shape[0], 1)
        res = {}
        mc_samples = 50
        em, ev = self.encoder(torch.cat((x, c), 1))
        res['em'] = em
        res['ev'] = ev
        res['z'] = []
        res['x_pred'] = []
        res['mc_samples'] = mc_samples
        # Draw several Monte Carlo samples from the latent posterior and decode each.
        for _ in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(torch.cat((z, c), 1))
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res

    def compute_elbo(self, x, c, pred_model):
        c = torch.as_tensor(c).float()
        c = c.view(c.shape[0], 1)
        em, ev = self.encoder(torch.cat((x, c), 1))
        # KL divergence between the posterior N(em, ev) and the standard normal prior.
        kl_divergence = 0.5 * torch.mean(em ** 2 + ev - torch.log(ev) - 1, axis=1)

        z = self.sample_latent_code(em, ev)
        x_pred = self.decoder(torch.cat((z, c), 1))
        # The reconstruction term is handled by the training loss; a zero
        # placeholder is returned here.
        log_px_z = torch.tensor(0.0)

        return (torch.mean(log_px_z), torch.mean(kl_divergence), x, x_pred,
                torch.argmax(pred_model(x_pred), dim=1))
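
For orientation, a minimal smoke test of CF_VAE might look like the sketch below. The _FakeData shim, its feature dimensions, and the pred_model classifier are hypothetical stand-ins for the dice_ml data interface (which normally supplies ohe_encoded_feature_names and get_data_params_for_gradient_dice()) and a trained downstream model; this is a sketch under those assumptions, not part of the library.

import torch
from torch import nn


class _FakeData:
    """Hypothetical stand-in for the dice_ml data interface used by CF_VAE."""
    ohe_encoded_feature_names = ['f%d' % i for i in range(8)]

    def get_data_params_for_gradient_dice(self):
        minx, maxx = [[0.0] * 8], [[1.0] * 8]
        cat_idx = [[5, 6, 7]]         # one one-hot encoded categorical feature
        cont_idx = [0, 1, 2, 3, 4]    # recomputed inside __init__ anyway
        return minx, maxx, cat_idx, cont_idx, [0.0] * 5, [1.0] * 5, [2] * 5


model = CF_VAE(_FakeData(), encoded_size=10)
model.eval()                               # use BatchNorm running stats, no dropout

x = torch.rand(4, 8)                       # 4 rows of scaled/one-hot features
c = torch.tensor([0., 1., 0., 1.])         # target class label per row
out = model(x, c)
print(out['em'].shape, len(out['x_pred']))     # torch.Size([4, 10]) 50

pred_model = nn.Linear(8, 2)               # placeholder downstream classifier
elbo_parts = model.compute_elbo(x, c, pred_model)
print(elbo_parts[1].item() >= 0)           # the KL term is non-negative

Note that with a shifted-sigmoid variance in (0.5, 1.5), the KL term v - log(v) - 1 is non-negative for any v > 0, which is why the final assertion holds regardless of the random inputs.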


class AutoEncoder(nn.Module):

    def __init__(self, d, encoded_size):
        super(AutoEncoder, self).__init__()

        self.encoded_size = encoded_size
        self.data_size = len(d.encoded_feature_names)
        self.encoded_categorical_feature_indexes = d.get_data_params()[2]

        # Continuous indexes are the complement of the flattened categorical
        # (one-hot) indexes.
        flattened_indexes = [item for sublist in self.encoded_categorical_feature_indexes for item in sublist]
        self.encoded_continuous_feature_indexes = [i for i in range(self.data_size) if i not in flattened_indexes]
        self.encoded_start_cat = len(self.encoded_continuous_feature_indexes)

        self.encoder_mean = nn.Sequential(
            nn.Linear(self.data_size, 20), nn.BatchNorm1d(20), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(20, 16), nn.BatchNorm1d(16), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(16, 14), nn.BatchNorm1d(14), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(14, 12), nn.BatchNorm1d(12), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(12, self.encoded_size)
        )

        self.encoder_var = nn.Sequential(
            nn.Linear(self.data_size, 20), nn.BatchNorm1d(20), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(20, 16), nn.BatchNorm1d(16), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(16, 14), nn.BatchNorm1d(14), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(14, 12), nn.BatchNorm1d(12), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(12, self.encoded_size), nn.Sigmoid()
        )

        self.decoder_mean = nn.Sequential(
            nn.Linear(self.encoded_size, 12), nn.BatchNorm1d(12), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(12, 14), nn.BatchNorm1d(14), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(14, 16), nn.BatchNorm1d(16), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(16, 20), nn.BatchNorm1d(20), nn.Dropout(0.1), nn.ReLU(),
            nn.Linear(20, self.data_size), nn.Sigmoid()
        )

    def encoder(self, x):
        mean = self.encoder_mean(x)
        # As in CF_VAE, this is the variance itself, here shifted by 0.05 to
        # keep it strictly positive.
        logvar = 0.05 + self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        mean = self.decoder_mean(z)
        return mean

    def sample_latent_code(self, mean, logvar):
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I);
        # `logvar` holds the variance, so sigma = sqrt(logvar).
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        # Gaussian log-likelihood of x under N(mean, logvar), reduced over `raxis`.
        return torch.sum(-0.5 * ((x - mean) * (1. / logvar) * (x - mean) + torch.log(logvar)), axis=raxis)

    def forward(self, x):
        res = {}
        mc_samples = 50
        em, ev = self.encoder(x)
        res['em'] = em
        res['ev'] = ev
        res['z'] = []
        res['x_pred'] = []
        res['mc_samples'] = mc_samples
        # Draw several Monte Carlo samples from the latent posterior and decode each.
        for _ in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(z)
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res
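
AutoEncoder follows the same pattern but is unconditioned (no class label is concatenated to the input or the latent code). A minimal sketch of driving it, with _FakeLegacyData as a hypothetical stand-in for the older dice_ml data interface exposing encoded_feature_names and get_data_params():

import torch


class _FakeLegacyData:
    """Hypothetical stand-in for the legacy dice_ml data interface."""
    encoded_feature_names = ['f%d' % i for i in range(8)]

    def get_data_params(self):
        # Only element 2 (the categorical index groups) is consumed above.
        return None, None, [[5, 6, 7]]


ae = AutoEncoder(_FakeLegacyData(), encoded_size=10)
ae.eval()

x = torch.rand(4, 8)
out = ae(x)
print(out['em'].shape, out['mc_samples'])      # torch.Size([4, 10]) 50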