dice_ml.model_interfaces package

Submodules

dice_ml.model_interfaces.base_model module

Module containing a template class that serves as an interface to an ML model. Subclasses implement model interfaces for different ML frameworks such as TensorFlow, PyTorch, or scikit-learn. All model interface methods are in dice_ml.model_interfaces.
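
The template pattern described above can be sketched in plain Python. This is a minimal illustration, not the DiCE source: ModelInterface, CallableModelInterface, and toy_score are hypothetical names introduced here, and only the method names load_model and get_output are taken from this page.

```python
class ModelInterface:
    """Hypothetical template: stores a model and names the methods subclasses fill in."""

    def __init__(self, model=None, model_path="", backend=""):
        self.model = model
        self.model_path = model_path
        self.backend = backend

    def load_model(self):
        raise NotImplementedError("subclasses decide how a model is loaded")

    def get_output(self, input_instance):
        raise NotImplementedError("subclasses decide how a model is queried")


class CallableModelInterface(ModelInterface):
    """Hypothetical subclass for a 'framework' whose models are plain callables."""

    def load_model(self):
        # Nothing is read from disk in this sketch; the model must be passed in.
        if self.model is None:
            raise ValueError("no model supplied")

    def get_output(self, input_instance):
        # Query the wrapped callable once per input row.
        return [self.model(row) for row in input_instance]


def toy_score(row):
    # Stand-in for a trained model: the mean of the features.
    return sum(row) / len(row)


interface = CallableModelInterface(model=toy_score, backend="callable")
interface.load_model()
scores = interface.get_output([[0.0, 1.0], [1.0, 1.0]])
```

Code that consumes the interface only needs load_model and get_output, so the framework behind it can change without touching the caller.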

class dice_ml.model_interfaces.base_model.BaseModel(model=None, model_path='', backend='', func=None, kw_args=None)[source]

Bases: object

get_gradient()[source]
get_num_output_nodes(inp_size)[source]
get_num_output_nodes2(input_instance)[source]
get_output(input_instance, model_score=True)[source]

Returns prediction probabilities for a classifier and the predicted output for a regressor.

Returns

an array of output scores for a classifier, and a singleton array of the predicted value for a regressor.
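
To make the return contract concrete, the sketch below mimics it with two stub models. It assumes scikit-learn-style predict_proba and predict methods on the wrapped model. The stub classes, the model_type string, and the helper get_output_sketch are hypothetical, and this page does not say how model_score is handled, so the branch on it here is an assumption.

```python
class StubClassifier:
    def predict_proba(self, rows):
        # One [P(class 0), P(class 1)] pair per input row.
        return [[0.3, 0.7] for _ in rows]

    def predict(self, rows):
        # Hard labels, one per input row.
        return [1 for _ in rows]


class StubRegressor:
    def predict(self, rows):
        # One predicted value per input row.
        return [42.0 for _ in rows]


def get_output_sketch(model, model_type, input_instance, model_score=True):
    # Assumption: model_score=True asks for class scores where they exist;
    # otherwise the plain prediction is returned.
    if model_score and model_type == "classifier":
        return model.predict_proba(input_instance)
    return model.predict(input_instance)


query = [[25, 1, 0]]  # a single input row
clf_out = get_output_sketch(StubClassifier(), "classifier", query)
reg_out = get_output_sketch(StubRegressor(), "regressor", query)
```

For the single-row query, clf_out holds one row of class scores, while reg_out is a one-element sequence containing the predicted value.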

load_model()[source]

dice_ml.model_interfaces.keras_tensorflow_model module

Module containing an interface to a trained Keras TensorFlow model.

class dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel(model=None, model_path='', backend='TF1', func=None, kw_args=None)[source]

Bases: BaseModel

get_gradient(input_instance)[source]
get_num_output_nodes(inp_size)[source]
get_output(input_tensor, training=False, transform_data=False)[source]

Returns prediction probabilities.

Parameters
  • input_tensor – test input.

  • training – to determine training mode in TF2.

  • transform_data – boolean to indicate if data transformation is required.
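
The effect of training and transform_data can be illustrated without TensorFlow. Keras models accept a training argument when called; the sketch below imitates that with a stub whose dropout-like step is active only in training mode. StubKerasModel, scale_features, and keras_style_get_output are hypothetical names, and applying the transformation before the forward pass is an assumption about ordering rather than something this page states.

```python
class StubKerasModel:
    """Imitates a Keras-style callable whose behaviour depends on `training`."""

    def __call__(self, rows, training=False):
        outputs = []
        for row in rows:
            # Dropout-like step: ignore the first feature in training mode only.
            kept = row[1:] if training else row
            outputs.append(sum(kept) / len(row))
        return outputs


def scale_features(rows, max_value=100.0):
    # Hypothetical data transformation: rescale raw features by a fixed maximum.
    return [[value / max_value for value in row] for row in rows]


def keras_style_get_output(model, input_tensor, training=False, transform_data=False):
    # Assumption: raw inputs are transformed first, then passed to the model.
    if transform_data:
        input_tensor = scale_features(input_tensor)
    return model(input_tensor, training=training)


raw = [[50.0, 100.0]]
inference_out = keras_style_get_output(StubKerasModel(), raw, transform_data=True)
training_out = keras_style_get_output(
    StubKerasModel(), raw, training=True, transform_data=True
)
```

The two calls differ only in the training flag, yet produce different outputs, which is why the flag is left at False when scoring inputs.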

load_model()[source]

dice_ml.model_interfaces.pytorch_model module

Module containing an interface to a trained PyTorch model.

class dice_ml.model_interfaces.pytorch_model.PyTorchModel(model=None, model_path='', backend='PYT', func=None, kw_args=None)[source]

Bases: BaseModel

get_gradient(input_instance)[source]
get_num_output_nodes(inp_size)[source]
get_output(input_instance, model_score=True, transform_data=False, out_tensor=False)[source]

Returns prediction probabilities.

Parameters
  • input_instance – test input.

  • transform_data – boolean to indicate if data transformation is required.

load_model()[source]
set_eval_mode()[source]

Module contents