Source code for dice_ml.dice

"""Module pointing to different implementations of DiCE based on different
   frameworks such as Tensorflow or PyTorch or sklearn, and different methods
   such as RandomSampling, DiCEKD or DiCEGenetic"""

from raiutils.exceptions import UserConfigValidationException

from dice_ml.constants import BackEndTypes, SamplingStrategy
from dice_ml.data_interfaces.private_data_interface import PrivateData
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase


[docs]class Dice(ExplainerBase): """An interface class to different DiCE implementations.""" def __init__(self, data_interface, model_interface, method="random", **kwargs): """Init method :param data_interface: an interface to access data related params. :param model_interface: an interface to access the output or gradients of a trained ML model. :param method: Name of the method to use for generating counterfactuals """ self.decide_implementation_type(data_interface, model_interface, method, **kwargs)
[docs] def decide_implementation_type(self, data_interface, model_interface, method, **kwargs): """Decides DiCE implementation type.""" if model_interface.backend == BackEndTypes.Sklearn: if method == SamplingStrategy.KdTree and isinstance(data_interface, PrivateData): raise UserConfigValidationException( 'Private data interface is not supported with kdtree explainer' ' since kdtree explainer needs access to entire training data') self.__class__ = decide(model_interface, method) self.__init__(data_interface, model_interface, **kwargs)
def _generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", desired_range=None, permitted_range=None, features_to_vary="all", stopping_threshold=0.5, posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", verbose=False, **kwargs): raise NotImplementedError("This method should be implemented by the concrete classes " "that inherit from ExplainerBase")
[docs]def decide(model_interface, method): """Decides DiCE implementation type. To add new implementations of DiCE, add the class in explainer_interfaces subpackage and import-and-return the class in an elif loop as shown in the below method. """ if method == SamplingStrategy.Random: # random sampling of CFs from dice_ml.explainer_interfaces.dice_random import DiceRandom return DiceRandom elif method == SamplingStrategy.Genetic: from dice_ml.explainer_interfaces.dice_genetic import DiceGenetic return DiceGenetic elif method == SamplingStrategy.KdTree: from dice_ml.explainer_interfaces.dice_KD import DiceKD return DiceKD elif method == SamplingStrategy.Gradient: if model_interface.backend == BackEndTypes.Tensorflow1: # pretrained Keras Sequential model with Tensorflow 1.x backend from dice_ml.explainer_interfaces.dice_tensorflow1 import \ DiceTensorFlow1 return DiceTensorFlow1 elif model_interface.backend == BackEndTypes.Tensorflow2: # pretrained Keras Sequential model with Tensorflow 2.x backend from dice_ml.explainer_interfaces.dice_tensorflow2 import \ DiceTensorFlow2 return DiceTensorFlow2 elif model_interface.backend == BackEndTypes.Pytorch: # PyTorch backend from dice_ml.explainer_interfaces.dice_pytorch import DicePyTorch return DicePyTorch else: raise UserConfigValidationException( "{0} is only supported for differentiable neural network models. " "Please choose one of {1}, {2} or {3}".format( method, SamplingStrategy.Random, SamplingStrategy.Genetic, SamplingStrategy.KdTree )) elif method is None: # all other backends backend_dice = model_interface.backend['explainer'] module_name, class_name = backend_dice.split('.') module = __import__("dice_ml.explainer_interfaces." + module_name, fromlist=[class_name]) return getattr(module, class_name) else: raise UserConfigValidationException("Unsupported sample strategy {0} provided. " "Please choose one of {1}, {2} or {3}".format( method, SamplingStrategy.Random, SamplingStrategy.Genetic, SamplingStrategy.KdTree ))