Generating counterfactual explanations with any ML model
The goal of this notebook is to show how to generate CFs for ML models built with frameworks other than TensorFlow or PyTorch. We show how to generate diverse CFs by three methods:
1. Independent random sampling of features (method='random')
2. Genetic algorithm (method='genetic')
3. Querying a KD tree (method='kdtree')
We use scikit-learn models for demonstration.
1. Independent random sampling of features
[1]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers # helper functions
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.ensemble import RandomForestClassifier
[2]:
%load_ext autoreload
%autoreload 2
Loading dataset
We use the “adult” income dataset from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). We transform the data as described in the dice_ml.utils.helpers module.
[3]:
dataset = helpers.load_adult_income_dataset()
[4]:
dataset.head()
[4]:
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 28 | Private | Bachelors | Single | White-Collar | White | Female | 60 | 0 |
1 | 30 | Self-Employed | Assoc | Married | Professional | White | Male | 65 | 1 |
2 | 32 | Private | Some-college | Married | White-Collar | White | Male | 50 | 0 |
3 | 20 | Private | Some-college | Single | Service | White | Female | 35 | 0 |
4 | 41 | Self-Employed | Some-college | Married | White-Collar | White | Male | 50 | 0 |
[5]:
d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')
Training a custom ML model
Below, we build an ML model using scikit-learn to demonstrate how our methods can work with any sklearn model.
[6]:
target = dataset["income"]
# Split data into train and test
datasetX = dataset.drop("income", axis=1)
x_train, x_test, y_train, y_test = train_test_split(datasetX,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0,
                                                    stratify=target)
numerical = ["age", "hours_per_week"]
categorical = x_train.columns.difference(numerical)
# We create the preprocessing pipelines for both numeric and categorical data.
numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])
transformations = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numerical),
        ('cat', categorical_transformer, categorical)])
# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf = Pipeline(steps=[('preprocessor', transformations),
                      ('classifier', RandomForestClassifier())])
model = clf.fit(x_train, y_train)
[7]:
# provide the trained ML model to DiCE's model object
backend = 'sklearn'
m = dice_ml.Model(model=model, backend=backend)
Generate diverse counterfactuals
[8]:
# initiate DiCE
exp_random = dice_ml.Dice(d, m, method="random")
[9]:
query_instances = x_train[4:6]
[10]:
# generate counterfactuals
dice_exp_random = exp_random.generate_counterfactuals(query_instances, total_CFs=2, desired_class="opposite", verbose=False)
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.98it/s]
[11]:
dice_exp_random.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 42 | - | Prof-school | - | - | - | - | - | 1 |
1 | 43 | - | Prof-school | - | - | - | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | Prof-school | Separated | - | - | - | - | 0 |
1 | 51 | - | Masters | - | - | - | - | - | 0 |
The random sampling method produces less sparse CFs than DiCE’s current implementation. The sparsity issue with random sampling worsens as total_CFs increases.
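Sparsity here refers to how few features a counterfactual changes relative to the query. A minimal sketch of that count follows, using plain dictionaries in place of DiCE's result objects. The helper name `changed_features` is hypothetical and is not part of dice_ml; the values are copied from the first query instance and its first counterfactual shown above.

```python
def changed_features(query, counterfactual, outcome_name="income"):
    """Return the names of features whose values differ from the query."""
    return [
        name for name, value in query.items()
        if name != outcome_name and counterfactual.get(name) != value
    ]

# First query instance and its first counterfactual from the output above.
query = {"age": 27, "workclass": "Private", "education": "School",
         "marital_status": "Single", "occupation": "Blue-Collar",
         "race": "White", "gender": "Male", "hours_per_week": 40, "income": 0}
cf = dict(query, age=42, education="Prof-school", income=1)

changed = changed_features(query, cf)
print(changed)       # ['age', 'education']
print(len(changed))  # 2; fewer changed features means a sparser counterfactual
```

Comparing this count across methods, or across values of total_CFs, gives a rough way to check the sparsity observation on your own runs.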
Further, different sets of counterfactuals can be generated with different random seeds.
[12]:
# generate counterfactuals
# default random seed is 17
dice_exp_random = exp_random.generate_counterfactuals(query_instances,
                                                      total_CFs=4,
                                                      desired_class="opposite",
                                                      random_seed=9)
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.49it/s]
[13]:
dice_exp_random.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 87 | - | Assoc | - | - | - | - | - | 1 |
1 | 67 | - | Assoc | - | - | - | - | - | 1 |
2 | 61 | - | Assoc | - | - | - | - | - | 1 |
3 | 72 | - | Assoc | - | - | - | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | - | Single | Professional | - | - | - | 0 |
1 | - | - | - | - | - | - | Female | 16 | 0 |
2 | 44 | - | - | Separated | - | - | - | - | 0 |
3 | 19 | - | - | - | - | - | - | - | 0 |
Selecting the features to vary
Here, you can ensure that DiCE varies only the features that it makes sense to vary.
[14]:
# generate counterfactuals
dice_exp_random = exp_random.generate_counterfactuals(
    query_instances, total_CFs=4, desired_class="opposite",
    features_to_vary=['workclass', 'education', 'occupation', 'hours_per_week'])
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.64it/s]
[15]:
dice_exp_random.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | Bachelors | - | Sales | - | - | 82 | 1 |
1 | - | - | Bachelors | - | White-Collar | - | - | 85 | 1 |
2 | - | - | Bachelors | - | Sales | - | - | 70 | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | HS-grad | - | - | - | - | 34 | 0 |
1 | - | Other/Unknown | Doctorate | - | - | - | - | - | 0 |
2 | - | - | - | - | - | - | - | 10 | 0 |
3 | - | - | - | - | Blue-Collar | - | - | 16 | 0 |
Choosing feature ranges
Since the features are sampled randomly, they can vary freely across their range. In the example below, we show how the range of continuous features can be controlled using the permitted_range parameter, which can be passed during CF generation.
[16]:
# generate counterfactuals
dice_exp_random = exp_random.generate_counterfactuals(
    query_instances, total_CFs=4, desired_class="opposite",
    permitted_range={'age': [22, 50], 'hours_per_week': [40, 60]})
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.17it/s]
[17]:
dice_exp_random.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 37 | - | Prof-school | - | - | - | - | - | 1 |
1 | 42 | - | Prof-school | - | - | - | - | - | 1 |
2 | 38 | - | - | - | - | - | - | 60 | 1 |
3 | 38 | - | Prof-school | - | - | - | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | - | Divorced | Other/Unknown | - | - | - | 0 |
1 | 24 | - | - | - | - | White | - | - | 0 |
2 | - | - | School | - | - | - | - | 52 | 0 |
3 | - | Private | - | - | - | - | - | - | 0 |
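To make the idea of independent random sampling concrete, here is a toy sketch under stated assumptions. It is not DiCE's implementation: `FEATURES`, `toy_predict`, and `random_counterfactuals` are hypothetical names, the "model" is a hand-written rule, and the permitted range is handled only for numeric features. It shows how restricting the features to vary and their ranges narrows what the sampler may propose, and how a fixed seed makes the result repeatable.

```python
import random

# Hypothetical toy setup: allowed values per feature and a stand-in "model".
FEATURES = {
    "age": range(17, 91),
    "hours_per_week": range(1, 100),
    "education": ["School", "HS-grad", "Bachelors", "Masters"],
}

def toy_predict(row):
    """Stand-in classifier: predicts 1 for long hours or a degree."""
    return int(row["hours_per_week"] >= 50
               or row["education"] in ("Bachelors", "Masters"))

def random_counterfactuals(query, total_cfs, features_to_vary=None,
                           permitted_range=None, seed=17, max_tries=10_000):
    rng = random.Random(seed)
    features_to_vary = list(features_to_vary or FEATURES)
    permitted_range = permitted_range or {}
    desired = 1 - toy_predict(query)  # "opposite" class for a binary outcome
    found = []
    for _ in range(max_tries):
        if len(found) == total_cfs:
            break
        candidate = dict(query)
        # Resample a random subset of the allowed features, each independently.
        k = rng.randint(1, len(features_to_vary))
        for name in rng.sample(features_to_vary, k):
            if name in permitted_range:  # numeric [low, high] bounds only
                low, high = permitted_range[name]
                candidate[name] = rng.randint(low, high)
            else:
                candidate[name] = rng.choice(list(FEATURES[name]))
        # Keep the candidate only if it flips the prediction and is new.
        if toy_predict(candidate) == desired and candidate not in found:
            found.append(candidate)
    return found

query = {"age": 27, "education": "School", "hours_per_week": 40}
cfs = random_counterfactuals(
    query, total_cfs=3,
    features_to_vary=["education", "hours_per_week"],
    permitted_range={"hours_per_week": [40, 60]},
)
for cf in cfs:
    print(cf)
```

Because each candidate resamples an arbitrary subset of features, nothing in this procedure favors changing few features, which is consistent with the sparsity observation made earlier for the random method.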
2. Genetic Algorithm
Here, we show how DiCE can be used to generate CFs for any ML model by using a genetic algorithm to find the best counterfactuals close to the query point. The genetic algorithm converges quickly and promotes diverse counterfactuals.
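As a rough illustration of the idea, the sketch below runs a very small genetic search on a toy problem. It is an assumption-laden stand-in, not DiCE's algorithm: the rule-based `toy_predict`, the fitness weights, and all function names are hypothetical, and the diversity term that DiCE's genetic method promotes is omitted for brevity.

```python
import random

EDUCATION = ["School", "HS-grad", "Bachelors", "Masters"]

def toy_predict(row):
    """Stand-in classifier (hypothetical): 1 for long hours or a degree."""
    return int(row["hours_per_week"] >= 50
               or row["education"] in ("Bachelors", "Masters"))

def random_individual(rng):
    return {"age": rng.randint(17, 90),
            "hours_per_week": rng.randint(1, 99),
            "education": rng.choice(EDUCATION)}

def fitness(individual, query, desired):
    """Higher is better: reward the desired class, penalise distance to the query."""
    validity = 1.0 if toy_predict(individual) == desired else 0.0
    distance = (abs(individual["age"] - query["age"]) / 73
                + abs(individual["hours_per_week"] - query["hours_per_week"]) / 98
                + (individual["education"] != query["education"]))
    return 10.0 * validity - distance  # distance is at most 3, so validity dominates

def genetic_counterfactual(query, desired, generations=30, pop_size=40, seed=0):
    rng = random.Random(seed)
    population = [random_individual(rng) for _ in range(pop_size)]
    for _ in range(generations):
        population.sort(key=lambda ind: fitness(ind, query, desired), reverse=True)
        parents = population[: pop_size // 2]  # selection: keep the fitter half
        children = []
        while len(parents) + len(children) < pop_size:
            mother, father = rng.sample(parents, 2)
            # Crossover: each feature comes from one of the two parents.
            child = {k: rng.choice((mother[k], father[k])) for k in query}
            # Mutation: occasionally perturb age, clipped to its range.
            if rng.random() < 0.3:
                child["age"] = min(90, max(17, child["age"] + rng.randint(-5, 5)))
            children.append(child)
        population = parents + children
    return max(population, key=lambda ind: fitness(ind, query, desired))

query = {"age": 27, "hours_per_week": 40, "education": "School"}
best = genetic_counterfactual(query, desired=1)
print(best, toy_predict(best))
```

Keeping the fitter half of each generation means the best candidate found so far is never discarded, so the search can only move closer to the query while staying in the desired class.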
Training a custom ML model
Currently, the genetic algorithm method works with scikit-learn models. We will use the same model as shown previously in the notebook. Support for TensorFlow 1 & 2 and PyTorch will be implemented soon.
Generate diverse counterfactuals
[18]:
# initiate DiceGenetic
exp_genetic = dice_ml.Dice(d, m, method='genetic')
[19]:
# generate counterfactuals
dice_exp_genetic = exp_genetic.generate_counterfactuals(query_instances, total_CFs=4, desired_class="opposite", verbose=True)
0%| | 0/2 [00:00<?, ?it/s]
Initializing initial parameters to the genetic algorithm...
Initialization complete! Generating counterfactuals...
50%|██████████████████████████████████████████▌ | 1/2 [00:03<00:03, 3.78s/it]
Diverse Counterfactuals found! total time taken: 00 min 03 sec
Initializing initial parameters to the genetic algorithm...
Initialization complete! Generating counterfactuals...
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00, 2.46s/it]
Diverse Counterfactuals found! total time taken: 00 min 01 sec
[20]:
dice_exp_genetic.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | Some-college | Married | Professional | - | - | - | 1 |
0 | - | - | Bachelors | Married | Sales | - | - | - | 1 |
0 | - | - | Assoc | Married | White-Collar | - | - | - | 1 |
0 | - | - | Some-college | Married | - | Other | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | School | - | Professional | White | - | - | 0 |
0 | - | - | HS-grad | - | Blue-Collar | White | - | - | 0 |
0 | - | Private | School | - | Blue-Collar | - | - | - | 0 |
0 | - | - | HS-grad | - | White-Collar | White | - | - | 0 |
You can also ensure that the genetic algorithm varies only the features you wish to vary.
[21]:
# generate counterfactuals
dice_exp_genetic = exp_genetic.generate_counterfactuals(
    query_instances, total_CFs=2, desired_class="opposite",
    features_to_vary=['workclass', 'education', 'occupation', 'hours_per_week'])
dice_exp_genetic.visualize_as_dataframe(show_only_changes=True)
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:14<00:00, 7.28s/it]
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | Bachelors | - | White-Collar | - | - | 71 | 1 |
0 | - | - | Assoc | - | White-Collar | - | - | 82 | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | Assoc | - | - | - | - | - | 0 |
0 | - | Government | - | - | Professional | - | - | - | 0 |
You can also constrain the features to vary only within a permitted range.
[22]:
# generate counterfactuals
dice_exp_genetic = exp_genetic.generate_counterfactuals(
    query_instances, total_CFs=2, desired_class="opposite",
    permitted_range={'age': [22, 50], 'hours_per_week': [40, 60]})
dice_exp_genetic.visualize_as_dataframe(show_only_changes=True)
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.25s/it]
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | - | Some-college | Married | Professional | - | - | - | 1 |
0 | - | - | Bachelors | Married | Sales | - | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | - | Private | Bachelors | - | White-Collar | - | - | - | 0 |
0 | - | Private | Assoc | - | - | White | - | - | 0 |
3. Querying a KD Tree
Here, we show how DiCE can be used to generate CFs for any ML model by finding the closest points in the dataset that yield the desired class. We do this efficiently by building a KD tree for each class and querying the KD tree of the desired class to find the k closest counterfactuals from the dataset. Drawing counterfactuals from the training data itself is intended to ensure that the counterfactuals shown are feasible.
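The per-class KD tree lookup described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not DiCE's code: the rows are made-up values, only two numeric features are used with raw squared Euclidean distance, and `build_kdtree` and `k_nearest` are hypothetical helper names.

```python
import heapq

def build_kdtree(points, depth=0):
    """points: list of (coords, row_id). Returns a nested dict, or None if empty."""
    if not points:
        return None
    axis = depth % len(points[0][0])  # cycle through the coordinates
    points = sorted(points, key=lambda p: p[0][axis])
    mid = len(points) // 2
    return {"point": points[mid], "axis": axis,
            "left": build_kdtree(points[:mid], depth + 1),
            "right": build_kdtree(points[mid + 1:], depth + 1)}

def _search(node, target, k, heap):
    if node is None:
        return
    coords, row_id = node["point"]
    dist2 = sum((a - b) ** 2 for a, b in zip(coords, target))
    # heapq is a min-heap, so negated distances keep the worst match on top.
    if len(heap) < k:
        heapq.heappush(heap, (-dist2, row_id))
    elif dist2 < -heap[0][0]:
        heapq.heapreplace(heap, (-dist2, row_id))
    diff = target[node["axis"]] - coords[node["axis"]]
    near, far = ((node["left"], node["right"]) if diff < 0
                 else (node["right"], node["left"]))
    _search(near, target, k, heap)
    # Visit the far side only if it could still hold a closer point.
    if len(heap) < k or diff ** 2 < -heap[0][0]:
        _search(far, target, k, heap)

def k_nearest(tree, target, k):
    """Return up to k (squared_distance, row_id) pairs, nearest first."""
    heap = []
    _search(tree, target, k, heap)
    return sorted((-neg, row_id) for neg, row_id in heap)

rows = [  # (row_id, age, hours_per_week, predicted_income) -- made-up values
    (0, 26, 40, 1), (1, 28, 45, 1), (2, 45, 60, 1), (3, 33, 38, 1), (4, 52, 50, 1),
    (5, 25, 40, 0), (6, 30, 35, 0), (7, 61, 20, 0), (8, 19, 30, 0),
]
# One tree per class, so a query only searches rows with the desired outcome.
trees = {label: build_kdtree([((age, hours), rid)
                              for rid, age, hours, y in rows if y == label])
         for label in (0, 1)}

query = (27, 40)  # a point currently predicted as class 0
nearest = k_nearest(trees[1], query, k=2)
print(nearest)  # [(1, 0), (26, 1)]
```

The returned row ids play the same role as the dataset indices shown in the counterfactual tables for this method: each counterfactual is an existing row rather than a synthesized point.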
Training a custom ML model
Currently, the KD tree method works with scikit-learn models. Again, we will use the same model as shown previously in the notebook. Support for TensorFlow 1 & 2 and PyTorch will be implemented soon.
Generate diverse counterfactuals
[23]:
# initiate DiceKD
exp_KD = dice_ml.Dice(d, m, method='kdtree')
[24]:
# generate counterfactuals
dice_exp_KD = exp_KD.generate_counterfactuals(query_instances, total_CFs=4, desired_class="opposite")
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.01it/s]
[25]:
dice_exp_KD.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
11233 | 26 | - | Bachelors | Married | - | - | - | - | 1 |
3155 | 28 | - | Assoc | Married | - | - | - | - | 1 |
2571 | - | - | Some-college | Married | - | Other | - | - | 1 |
8361 | - | - | Some-college | Married | Professional | - | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
15736 | 32 | - | - | - | - | - | - | - | 0 |
20071 | 29 | - | - | - | Blue-Collar | - | - | - | 0 |
17010 | - | - | - | - | - | White | - | - | 0 |
15712 | - | - | Assoc | - | Professional | - | - | - | 0 |
Selecting the features to vary
Here, again, you can restrict DiCE to the features that you wish to vary. Note that the output counterfactuals are drawn only from the training data. If you want other counterfactuals, use the random or genetic method.
[26]:
# generate counterfactuals
dice_exp_KD = exp_KD.generate_counterfactuals(
    query_instances, total_CFs=4, desired_class="opposite",
    features_to_vary=['age', 'workclass', 'education', 'occupation', 'hours_per_week'])
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.45s/it]
[27]:
dice_exp_KD.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
Diverse Counterfactual set (new outcome: 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
1267 | 25 | - | Assoc | - | Sales | - | - | - | 1 |
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
15736 | 32 | - | - | - | - | - | - | - | 0 |
20071 | 29 | - | - | - | Blue-Collar | - | - | - | 0 |
15712 | - | - | Assoc | - | Professional | - | - | - | 0 |
554 | - | Private | Bachelors | - | White-Collar | - | - | - | 0 |
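Because this method returns existing rows, restricting the features to vary shrinks the pool of usable rows: a row can serve as a counterfactual only if it agrees with the query on every feature that must stay fixed. That may be why fewer than the four requested counterfactuals appear for the first query above. The sketch below illustrates the filtering idea with hypothetical rows and a hypothetical helper, `matches_fixed_features`; it is a conceptual illustration, not a description of DiCE's internals.

```python
def matches_fixed_features(query, row, features_to_vary):
    """True if the row agrees with the query on every feature not allowed to vary."""
    return all(row[name] == value
               for name, value in query.items()
               if name not in features_to_vary)

query = {"age": 27, "education": "School", "marital_status": "Single", "gender": "Male"}
# Hypothetical dataset rows already predicted as the desired class.
candidates = [
    {"age": 25, "education": "Assoc", "marital_status": "Single", "gender": "Male"},
    {"age": 26, "education": "Bachelors", "marital_status": "Married", "gender": "Male"},
    {"age": 28, "education": "Assoc", "marital_status": "Married", "gender": "Male"},
]
allowed = ["age", "education"]
usable = [row for row in candidates if matches_fixed_features(query, row, allowed)]
print(len(usable))  # 1: only the row that keeps marital_status and gender unchanged
```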
Selecting the feature ranges
Here, you can control the ranges of continuous features.
[28]:
# generate counterfactuals
dice_exp_KD = exp_KD.generate_counterfactuals(
    query_instances, total_CFs=5, desired_class="opposite",
    permitted_range={'age': [30, 50], 'hours_per_week': [40, 60]})
dice_exp_KD.visualize_as_dataframe(show_only_changes=True)
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.38s/it]
Query instance (original outcome : 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 27 | Private | School | Single | Blue-Collar | White | Male | 40 | 0 |
No counterfactuals found!
Query instance (original outcome : 1)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
0 | 31 | Self-Employed | Some-college | Married | Sales | Other | Male | 60 | 1 |
Diverse Counterfactual set (new outcome: 0)
age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income | |
---|---|---|---|---|---|---|---|---|---|
15736 | 32 | - | - | - | - | - | - | - | 0 |
17010 | - | - | - | - | - | White | - | - | 0 |
12896 | - | - | Assoc | - | - | White | - | - | 0 |
15712 | - | - | Assoc | - | Professional | - | - | - | 0 |
22686 | - | Private | Bachelors | - | - | White | - | - | 0 |