Source code for dice_ml.counterfactual_explanations

import json
import os

import jsonschema
from raiutils.exceptions import UserConfigValidationException

from dice_ml.constants import _SchemaVersions
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
                                             _DiverseCFV2SchemaConstants)


class _CommonSchemaConstants:
    LOCAL_IMPORTANCE = 'local_importance'
    SUMMARY_IMPORTANCE = 'summary_importance'
    METADATA = 'metadata'


class _CounterfactualExpV1SchemaConstants:
    """Dictionary keys of the version 1 serialized explanations."""

    # Keys re-exported from the common constants.
    METADATA = _CommonSchemaConstants.METADATA
    SUMMARY_IMPORTANCE = _CommonSchemaConstants.SUMMARY_IMPORTANCE
    LOCAL_IMPORTANCE = _CommonSchemaConstants.LOCAL_IMPORTANCE

    # Key specific to version 1: the list of serialized
    # CounterfactualExamples strings.
    CF_EXAMPLES_LIST = 'cf_examples_list'


class _CounterfactualExpV2SchemaConstants:
    """Dictionary keys of the version 2 serialized explanations."""

    # Keys re-exported from the common constants.
    METADATA = _CommonSchemaConstants.METADATA
    SUMMARY_IMPORTANCE = _CommonSchemaConstants.SUMMARY_IMPORTANCE
    LOCAL_IMPORTANCE = _CommonSchemaConstants.LOCAL_IMPORTANCE

    # Per-input payload: the test instances and their counterfactuals.
    TEST_DATA = 'test_data'
    CFS_LIST = 'cfs_list'

    # Fields stored once for the whole set of explanations.
    DATA_INTERFACE = 'data_interface'
    MODEL_TYPE = 'model_type'
    DESIRED_CLASS = 'desired_class'
    DESIRED_RANGE = 'desired_range'
    FEATURE_NAMES = 'feature_names'
    FEATURE_NAMES_INCLUDING_TARGET = 'feature_names_including_target'


def _check_supported_json_output_versions(version):
    """Tell whether ``version`` is one of the known schema versions.

    :param version: Version value to look up.
    :return: Result of the membership test against
        ``_SchemaVersions.ALL_VERSIONS``.
    """
    known_versions = _SchemaVersions.ALL_VERSIONS
    is_supported = version in known_versions
    return is_supported


class CounterfactualExplanations:
    """A class to store counterfactual examples for one or more inputs
    and feature importance scores.

    :param cf_examples_list: A list of CounterfactualExamples instances
    :param local_importance: List of estimated local importance scores. The
        size of the list is the number of input instances, each containing
        feature importance scores for that input.
    :param summary_importance: Estimated global feature importance scores
        based on the input set of CounterfactualExamples instances
    :param version: Schema version stored in the metadata and used by
        :meth:`to_json`. When ``None``, ``_SchemaVersions.CURRENT_VERSION``
        is used.

    """

    def __init__(self, cf_examples_list,
                 local_importance=None,
                 summary_importance=None,
                 version=None):
        self._cf_examples_list = cf_examples_list
        self._local_importance = local_importance
        self._summary_importance = summary_importance
        if version is None:
            version = _SchemaVersions.CURRENT_VERSION
        self._metadata = {'version': version}

    def __eq__(self, other_cf):
        """Compare examples, importance scores and metadata.

        Anything that is not a CounterfactualExplanations compares unequal.
        """
        if isinstance(other_cf, CounterfactualExplanations):
            return (self.cf_examples_list == other_cf.cf_examples_list and
                    self.local_importance == other_cf.local_importance and
                    self.summary_importance == other_cf.summary_importance and
                    self.metadata == other_cf.metadata)
        return False

    @staticmethod
    def _sort_by_importance(importance):
        """Return a new dict ordered by descending importance score.

        Scores that are ``None`` are placed last instead of being compared
        with numbers. A version 2 round trip can produce such entries:
        ``to_json`` fills the score of a feature that is missing from an
        importance dict with ``None``.
        """
        return dict(sorted(
            importance.items(),
            # (False, 0) sorts below every (True, score) tuple, so with
            # reverse=True the missing scores end up at the tail.
            key=lambda item: (item[1] is not None,
                              0 if item[1] is None else item[1]),
            reverse=True))

    @property
    def cf_examples_list(self):
        """The stored list of CounterfactualExamples instances."""
        return self._cf_examples_list

    @property
    def local_importance(self):
        """Local importance scores, each dict sorted by descending score.

        When the stored value is a list, the sorted list replaces it on the
        instance. Any other value (e.g. ``None``) is returned unchanged.
        """
        if isinstance(self._local_importance, list):
            self._local_importance = [
                self._sort_by_importance(local_importance_instance)
                for local_importance_instance in self._local_importance
            ]
        return self._local_importance

    @property
    def summary_importance(self):
        """Summary importance scores sorted by descending score.

        When the stored value is a dict, the sorted dict replaces it on the
        instance. Any other value (e.g. ``None``) is returned unchanged.
        """
        if isinstance(self._summary_importance, dict):
            self._summary_importance = self._sort_by_importance(
                self._summary_importance)
        return self._summary_importance

    @property
    def metadata(self):
        """Metadata dictionary; holds the serialization ``'version'``."""
        return self._metadata

    def visualize_as_dataframe(self, display_sparse_df=True,
                               show_only_changes=False):
        """Call ``visualize_as_dataframe`` on every stored set of examples.

        :param display_sparse_df: Forwarded unchanged to each call.
        :param show_only_changes: Forwarded unchanged to each call.
        """
        for cf_examples in self.cf_examples_list:
            cf_examples.visualize_as_dataframe(
                display_sparse_df=display_sparse_df,
                show_only_changes=show_only_changes)

    def visualize_as_list(self, display_sparse_df=True,
                          show_only_changes=False):
        """Call ``visualize_as_list`` on every stored set of examples.

        :param display_sparse_df: Forwarded unchanged to each call.
        :param show_only_changes: Forwarded unchanged to each call.
        """
        for cf_examples in self.cf_examples_list:
            cf_examples.visualize_as_list(
                display_sparse_df=display_sparse_df,
                show_only_changes=show_only_changes)

    @staticmethod
    def _check_cf_exp_output_against_json_schema(
            cf_dict, version):
        """
        Validate the dictionary version of the counterfactual explanations.

        The schema is read from ``schema/counterfactual_explanations_v<version>.json``
        located next to this module and ``cf_dict`` is validated against it
        with ``jsonschema.validate``.

        :param cf_dict: Serialized version of the counterfactual explanations.
        :type cf_dict: Dict
        :param version: Schema version used to pick the schema file.
        """
        schema_file_name = 'counterfactual_explanations_v{0}.json'.format(version)
        schema_path = os.path.join(os.path.dirname(__file__),
                                   'schema', schema_file_name)
        # Read the schema as UTF-8 explicitly rather than relying on the
        # platform's default encoding.
        with open(schema_path, 'r', encoding='utf-8') as schema_file:
            schema_json = json.load(schema_file)

        jsonschema.validate(cf_dict, schema_json)

    def to_json(self):
        """ Serialize Explanations object to json.

        The layout depends on ``metadata['version']``. The resulting
        dictionary is validated against the matching JSON schema before it
        is dumped.

        :return: The json string.
        :raises UserConfigValidationException: If the stored version is
            neither ``_SchemaVersions.V1`` nor ``_SchemaVersions.V2``.
        """
        serialization_version = self.metadata['version']
        if serialization_version == _SchemaVersions.V1:
            cf_examples_str_list = [
                cf_examples.to_json(serialization_version=serialization_version)
                for cf_examples in self.cf_examples_list
            ]
            entire_dict = {
                _CounterfactualExpV1SchemaConstants.CF_EXAMPLES_LIST: cf_examples_str_list,
                _CounterfactualExpV1SchemaConstants.LOCAL_IMPORTANCE: self.local_importance,
                _CounterfactualExpV1SchemaConstants.SUMMARY_IMPORTANCE: self.summary_importance,
                _CounterfactualExpV1SchemaConstants.METADATA: self.metadata
            }
            CounterfactualExplanations._check_cf_exp_output_against_json_schema(
                entire_dict, version=serialization_version)
            return json.dumps(entire_dict)
        elif serialization_version == _SchemaVersions.V2:
            combined_test_instance_list = []
            combined_final_cfs_list = []
            data_interface = None
            feature_names = None
            feature_names_including_target = None
            model_type = None
            desired_class = None
            desired_range = None
            for cf_examples in self.cf_examples_list:
                cf_examples_str = cf_examples.to_json(
                    serialization_version=serialization_version)
                # We need to load the json again since we need to decompose the
                # counterfactual example into different schema fields
                serialized_cf_examples = json.loads(cf_examples_str)
                combined_test_instance_list.append(serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.TEST_INSTANCE_LIST])
                combined_final_cfs_list.append(serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.FINAL_CFS_LIST])
                # The fields below are stored once for the whole output; the
                # values of the last set of examples are the ones kept.
                data_interface = serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.DATA_INTERFACE]
                feature_names = serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.FEATURE_NAMES]
                feature_names_including_target = serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.FEATURE_NAMES_INCLUDING_TARGET]
                model_type = serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.MODEL_TYPE]
                desired_class = serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.DESIRED_CLASS]
                desired_range = serialized_cf_examples[
                    _DiverseCFV2SchemaConstants.DESIRED_RANGE]

            # Version 2 stores importance scores as lists aligned with
            # feature_names; a feature without a score is written as None.
            local_importance = self.local_importance
            local_importance_matrix = None
            if local_importance is not None:
                local_importance_matrix = [
                    [local_importance_dict.get(feature_name)
                     for feature_name in feature_names]
                    for local_importance_dict in local_importance
                ]

            summary_importance = self.summary_importance
            summary_importance_list = None
            if summary_importance is not None:
                summary_importance_list = [
                    summary_importance.get(feature_name)
                    for feature_name in feature_names
                ]

            entire_dict = {
                _CounterfactualExpV2SchemaConstants.TEST_DATA: combined_test_instance_list,
                _CounterfactualExpV2SchemaConstants.CFS_LIST: combined_final_cfs_list,
                _CounterfactualExpV2SchemaConstants.LOCAL_IMPORTANCE: local_importance_matrix,
                _CounterfactualExpV2SchemaConstants.SUMMARY_IMPORTANCE: summary_importance_list,
                _CounterfactualExpV2SchemaConstants.DATA_INTERFACE: data_interface,
                _CounterfactualExpV2SchemaConstants.FEATURE_NAMES: feature_names,
                _CounterfactualExpV2SchemaConstants.FEATURE_NAMES_INCLUDING_TARGET:
                    feature_names_including_target,
                _CounterfactualExpV2SchemaConstants.MODEL_TYPE: model_type,
                _CounterfactualExpV2SchemaConstants.DESIRED_CLASS: desired_class,
                _CounterfactualExpV2SchemaConstants.DESIRED_RANGE: desired_range,
                _CounterfactualExpV2SchemaConstants.METADATA: self.metadata
            }
            CounterfactualExplanations._check_cf_exp_output_against_json_schema(
                entire_dict, version=serialization_version)
            return json.dumps(entire_dict)
        else:
            raise UserConfigValidationException(
                "Unsupported serialization version {}".format(serialization_version))

    @staticmethod
    def _validate_serialization_version(version):
        """Reject a missing or unsupported serialization version.

        :param version: Version read from the json metadata.
        :raises UserConfigValidationException: If ``version`` is ``None`` or
            ``_check_supported_json_output_versions`` rejects it.
        """
        if version is None:
            raise UserConfigValidationException("No version field in the json input")
        elif not _check_supported_json_output_versions(version):
            raise UserConfigValidationException(
                "Incompatible version {} found in json input".format(version))

    @staticmethod
    def from_json(json_str):
        """ Deserialize json string to a CounterfactualExplanations object.

        :param json_str: The json string to parse.
        :return: A CounterfactualExplanations object. If the parsed json has
            no ``'metadata'`` key, the parsed dictionary itself is returned.
        :raises UserConfigValidationException: If the metadata carries no
            version or an unsupported one.
        """
        json_dict = json.loads(json_str)
        if _CommonSchemaConstants.METADATA not in json_dict:
            return json_dict

        version = json_dict[_CommonSchemaConstants.METADATA].get('version')
        CounterfactualExplanations._validate_serialization_version(version)
        CounterfactualExplanations._check_cf_exp_output_against_json_schema(
            json_dict, version=version)

        if version == _SchemaVersions.V1:
            cf_examples_list = [
                CounterfactualExamples.from_json(cf_examples_str)
                for cf_examples_str in json_dict[
                    _CounterfactualExpV1SchemaConstants.CF_EXAMPLES_LIST]
            ]
            return CounterfactualExplanations(
                cf_examples_list=cf_examples_list,
                local_importance=json_dict[
                    _CounterfactualExpV1SchemaConstants.LOCAL_IMPORTANCE],
                summary_importance=json_dict[
                    _CounterfactualExpV1SchemaConstants.SUMMARY_IMPORTANCE],
                version=version)

        cf_examples_list = []
        for index, final_cfs in enumerate(
                json_dict[_CounterfactualExpV2SchemaConstants.CFS_LIST]):
            # We need to save the json again since we need to recompose the
            # counterfactual example.
            cf_examples_str = json.dumps(
                {
                    _DiverseCFV2SchemaConstants.FINAL_CFS_LIST: final_cfs,
                    _DiverseCFV2SchemaConstants.TEST_INSTANCE_LIST: json_dict[
                        _CounterfactualExpV2SchemaConstants.TEST_DATA][index],
                    _DiverseCFV2SchemaConstants.DATA_INTERFACE: json_dict[
                        _CounterfactualExpV2SchemaConstants.DATA_INTERFACE],
                    _DiverseCFV2SchemaConstants.DESIRED_CLASS: json_dict[
                        _CounterfactualExpV2SchemaConstants.DESIRED_CLASS],
                    _DiverseCFV2SchemaConstants.DESIRED_RANGE: json_dict[
                        _CounterfactualExpV2SchemaConstants.DESIRED_RANGE],
                    _DiverseCFV2SchemaConstants.MODEL_TYPE: json_dict[
                        _CounterfactualExpV2SchemaConstants.MODEL_TYPE],
                    _DiverseCFV2SchemaConstants.FEATURE_NAMES_INCLUDING_TARGET: json_dict[
                        _CounterfactualExpV2SchemaConstants.FEATURE_NAMES_INCLUDING_TARGET]
                }
            )
            cf_examples_list.append(
                CounterfactualExamples.from_json(cf_examples_str)
            )

        # Importance scores are stored as lists aligned with feature_names;
        # turn them back into {feature name: score} dictionaries. Indexing
        # feature_names (rather than zipping) keeps the IndexError when a
        # score list is longer than the feature name list.
        local_importance_list = None
        local_importance_matrix = json_dict[
            _CounterfactualExpV2SchemaConstants.LOCAL_IMPORTANCE]
        if local_importance_matrix is not None:
            feature_names = json_dict[
                _CounterfactualExpV2SchemaConstants.FEATURE_NAMES]
            local_importance_list = [
                {feature_names[position]: score
                 for position, score in enumerate(local_importance_instance)}
                for local_importance_instance in local_importance_matrix
            ]

        summary_importance_dict = None
        summary_importance_scores = json_dict[
            _CounterfactualExpV2SchemaConstants.SUMMARY_IMPORTANCE]
        if summary_importance_scores is not None:
            feature_names = json_dict[
                _CounterfactualExpV2SchemaConstants.FEATURE_NAMES]
            summary_importance_dict = {
                feature_names[position]: score
                for position, score in enumerate(summary_importance_scores)
            }

        return CounterfactualExplanations(
            cf_examples_list=cf_examples_list,
            local_importance=local_importance_list,
            summary_importance=summary_importance_dict,
            version=version)