Interpretable Classification

In this notebook we will fit three classification models: an Explainable Boosting Machine (EBM), a LogisticRegression model, and a ClassificationTree. Because all three are glassbox models, we can then inspect both their global explanations (what each model learned overall) and their local explanations (how each individual prediction was made).

This notebook can be found in our examples folder on GitHub.

# install interpret if not already installed
try:
    import interpret
except ModuleNotFoundError:
    !pip install --quiet interpret pandas scikit-learn
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from interpret import show
from interpret.perf import ROC

from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())

df = pd.read_csv(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
    header=None)
df.columns = [
    "Age", "WorkClass", "fnlwgt", "Education", "EducationNum",
    "MaritalStatus", "Occupation", "Relationship", "Race", "Gender",
    "CapitalGain", "CapitalLoss", "HoursPerWeek", "NativeCountry", "Income"
]
X = df.iloc[:, :-1]
y = (df.iloc[:, -1] == " >50K").astype(int)

seed = 42
np.random.seed(seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)
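The split above holds out 20% of the rows as a test set, and fixing random_state makes it reproducible. Below is a minimal stdlib sketch of a seeded shuffle-then-cut split. It assumes, as scikit-learn documents for a float test_size, that the test size is rounded up. The helper seeded_split is hypothetical and uses Python's random module, so its permutation will not match the rows scikit-learn actually picks.

```python
import math
import random


def seeded_split(n_rows, test_size=0.20, seed=42):
    """Hypothetical helper: return (train_indices, test_indices).

    Shuffles row indices with a private, seeded RNG and cuts off the
    first ceil(test_size * n_rows) indices as the test set.
    """
    n_test = math.ceil(test_size * n_rows)  # round the test size up
    indices = list(range(n_rows))
    rng = random.Random(seed)  # private RNG: global random state is untouched
    rng.shuffle(indices)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = seeded_split(100)
print(len(train_idx), len(test_idx))
```

Using a private RNG instead of the global one means the split does not depend on anything else that consumed random numbers earlier in the notebook.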

Explore the dataset

from interpret.data import ClassHistogram

hist = ClassHistogram().explain_data(X_train, y_train, name='Train Data')
show(hist)
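ClassHistogram plots each feature's distribution split by the target class. The sketch below shows roughly the counting behind one such plot: bin a numeric feature, then tally samples per (bin, class). The helper class_histogram and the toy ages and labels are hypothetical and are not taken from the Adult data.

```python
import bisect
from collections import Counter


def class_histogram(values, labels, edges):
    """Hypothetical helper: count samples per (bin index, class label).

    `edges` are ascending bin boundaries. Bin i covers [edges[i], edges[i+1]),
    and the last bin is also closed on the right.
    """
    counts = Counter()
    for value, label in zip(values, labels):
        if value < edges[0] or value > edges[-1]:
            continue  # outside the histogram's range
        i = bisect.bisect_right(edges, value) - 1
        i = min(i, len(edges) - 2)  # value == edges[-1] goes in the last bin
        counts[(i, label)] += 1
    return counts


# Toy data for illustration only.
ages = [22, 35, 41, 58, 63, 29]
income_over_50k = [0, 0, 1, 1, 0, 0]
print(class_histogram(ages, income_over_50k, edges=[20, 40, 60, 80]))
```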

Train the Explainable Boosting Machine (EBM)

from interpret.glassbox import ExplainableBoostingClassifier

ebm = ExplainableBoostingClassifier()
ebm.fit(X_train, y_train)
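An EBM is an additive model. Each feature is binned, each bin carries a learned score, and the predicted log-odds is the intercept plus one score per term, passed through the logistic function. The sketch below shows that lookup-and-sum step for main effects only. The helper, the cut points, and the scores are hypothetical. A real fitted EBM also has pairwise interaction terms and its own handling of missing or unseen values.

```python
import bisect
import math


def additive_predict_proba(sample, intercept, terms):
    """Hypothetical helper: P(y=1) under an EBM-style additive model.

    `terms` maps feature name -> (cut_points, bin_scores), where
    len(bin_scores) == len(cut_points) + 1.
    """
    logit = intercept
    for name, (cuts, scores) in terms.items():
        bin_index = bisect.bisect_right(cuts, sample[name])  # look up the bin
        logit += scores[bin_index]  # add that bin's learned score
    return 1.0 / (1.0 + math.exp(-logit))  # logistic link


# Hypothetical "fitted" terms, for illustration only.
terms = {
    "Age": ([30, 50], [-0.5, 0.2, 0.4]),
    "HoursPerWeek": ([40], [-0.1, 0.3]),
}
p = additive_predict_proba({"Age": 45, "HoursPerWeek": 50}, intercept=-1.0, terms=terms)
print(round(p, 4))
```

Because each term is a one-dimensional lookup table, every term can be plotted and read on its own. That is what makes the global and local explanations later in this notebook exact rather than approximate.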

EBMs are glassbox models, so we can edit them

# post-process monotonize the Age feature
ebm.monotonize("Age", increasing=True)
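Monotonizing a term replaces its bin scores with the closest non-decreasing (or non-increasing) sequence, which is an isotonic regression problem. The sketch below uses the pool-adjacent-violators algorithm on a list of bin scores. The weights stand in for per-bin sample counts, which is an assumption. interpret's own monotonize may weight or re-center the scores differently, so treat this as an illustration of the idea, not its implementation.

```python
def pava_increasing(values, weights=None):
    """Weighted isotonic regression (non-decreasing) via pool-adjacent-violators."""
    if weights is None:
        weights = [1.0] * len(values)
    blocks = []  # each block: [weighted mean, total weight, number of bins]
    for v, w in zip(values, weights):
        blocks.append([v, w, 1])
        # Merge backwards while the previous block's mean exceeds the current one.
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            m2, w2, c2 = blocks.pop()
            m1, w1, c1 = blocks.pop()
            total = w1 + w2
            blocks.append([(m1 * w1 + m2 * w2) / total, total, c1 + c2])
    result = []
    for mean, _, count in blocks:
        result.extend([mean] * count)  # every bin in a block gets the block mean
    return result


def pava_decreasing(values, weights=None):
    """Non-increasing fit: negate, fit increasing, negate back."""
    return [-v for v in pava_increasing([-v for v in values], weights)]


# Hypothetical Age bin scores with one dip that violates monotonicity.
print(pava_increasing([0.1, 0.3, 0.2, 0.5]))
```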

Global Explanations: What the model learned overall

ebm_global = ebm.explain_global(name='EBM')
show(ebm_global)
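The global overview ranks terms by importance. As we understand it, EBM importance is by default the average absolute score a term contributes across the data. The sketch below computes that quantity for hypothetical lookup-table terms and toy samples, so you can see why a term with large but rare scores can still rank below one with moderate, common scores.

```python
import bisect


def mean_abs_importance(samples, terms):
    """Hypothetical helper: mean |term score| over the samples, per term."""
    importances = {}
    for name, (cuts, scores) in terms.items():
        total = 0.0
        for sample in samples:
            total += abs(scores[bisect.bisect_right(cuts, sample[name])])
        importances[name] = total / len(samples)
    return importances


# Hypothetical terms and samples, for illustration only.
terms = {
    "Age": ([30, 50], [-0.5, 0.2, 0.4]),
    "HoursPerWeek": ([40], [-0.1, 0.3]),
}
samples = [
    {"Age": 25, "HoursPerWeek": 35},
    {"Age": 45, "HoursPerWeek": 50},
    {"Age": 60, "HoursPerWeek": 45},
]
print(mean_abs_importance(samples, terms))
```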

Local Explanations: How an individual prediction was made

ebm_local = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name='EBM')
show(ebm_local, 0)
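A local explanation for an additive model is the list of per-term scores for one row. Those scores plus the intercept add up exactly to that row's predicted log-odds. The sketch below builds that list, sorted by magnitude as a local bar chart would typically show it. The helper and the terms are hypothetical.

```python
import bisect


def local_contributions(sample, intercept, terms):
    """Hypothetical helper: per-term scores for one sample, largest |score| first.

    Returns (contributions, logit), where logit = intercept + sum of the scores.
    """
    contributions = [
        (name, scores[bisect.bisect_right(cuts, sample[name])])
        for name, (cuts, scores) in terms.items()
    ]
    contributions.sort(key=lambda item: abs(item[1]), reverse=True)
    logit = intercept + sum(score for _, score in contributions)
    return contributions, logit


# Hypothetical terms, for illustration only.
terms = {
    "Age": ([30, 50], [-0.5, 0.2, 0.4]),
    "HoursPerWeek": ([40], [-0.1, 0.3]),
}
contribs, logit = local_contributions({"Age": 25, "HoursPerWeek": 50}, -1.0, terms)
for name, score in contribs:
    print(f"{name:>14}: {score:+.2f}")
print(f"log-odds: {logit:+.2f}")
```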

Evaluate EBM performance

ebm_perf = ROC(ebm).explain_perf(X_test, y_test, name='EBM')
show(ebm_perf)
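The ROC explainer summarizes ranking quality with the area under the ROC curve (AUC). AUC equals the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counted as one half. The quadratic-time sketch below computes it directly from that definition. It is meant for checking small examples by hand, not as a replacement for the library.

```python
def roc_auc(scores, labels):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0  # positive ranked above negative
            elif p == n:
                wins += 0.5  # tie counts as half
    return wins / (len(positives) * len(negatives))


# Toy scores: 3 of the 4 positive/negative pairs are ordered correctly.
print(roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))
```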

Let's test out a few other explainable models

from interpret.glassbox import LogisticRegression, ClassificationTree

# One-hot encode the categorical variables: LogisticRegression and ClassificationTree need numeric inputs
X = pd.get_dummies(X, prefix_sep='.').astype(float)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)

lr = LogisticRegression(random_state=seed, penalty='l1', solver='liblinear')
lr.fit(X_train, y_train)

tree = ClassificationTree()
tree.fit(X_train, y_train)
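pd.get_dummies replaces each categorical column with one 0/1 indicator column per category, named column + prefix_sep + category (so "WorkClass.<value>" here). Numeric columns are passed through unchanged. The stdlib sketch below mimics that on a list of dicts. The helper one_hot and the toy rows are hypothetical. Re-running train_test_split with the same seed and test_size on the encoded frame should select the same rows as before, because the shuffle depends only on the number of rows and the seed.

```python
def one_hot(rows, categorical, prefix_sep="."):
    """Hypothetical helper: one-hot encode `categorical` keys of each row dict.

    Non-categorical values are cast to float. Categories are sorted so every
    row gets the same set of indicator columns.
    """
    categories = {col: sorted({row[col] for row in rows}) for col in categorical}
    encoded = []
    for row in rows:
        out = {k: float(v) for k, v in row.items() if k not in categorical}
        for col in categorical:
            for cat in categories[col]:
                out[f"{col}{prefix_sep}{cat}"] = 1.0 if row[col] == cat else 0.0
        encoded.append(out)
    return encoded


# Toy rows for illustration only.
rows = [{"Age": 39, "WorkClass": "Private"}, {"Age": 50, "WorkClass": "State-gov"}]
for r in one_hot(rows, categorical=["WorkClass"]):
    print(r)
```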

Compare the performance of the other models

lr_perf = ROC(lr).explain_perf(X_test, y_test, name='Logistic Regression')
show(lr_perf)
tree_perf = ROC(tree).explain_perf(X_test, y_test, name='Classification Tree')
show(tree_perf)

Glassbox: All of our models have global and local explanations

lr_global = lr.explain_global(name='Logistic Regression')
show(lr_global)
tree_global = tree.explain_global(name='Classification Tree')
show(tree_global)
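For logistic regression, the global explanation is essentially the learned coefficients. A local explanation for one row can be read as coefficient times feature value, and these contributions sum with the intercept to the log-odds. The sketch below shows that decomposition. The helper, coefficients, and intercept are hypothetical. Note that coefficient size also depends on each feature's scale, so raw coefficients on unscaled features (such as Age versus a 0/1 dummy) are not directly comparable.

```python
import math


def lr_local_explanation(x, coef, intercept):
    """Hypothetical helper: per-feature contributions and P(y=1) for one row.

    Each contribution is coef[name] * x[name]. Features missing from x count as 0.
    """
    contributions = {name: w * x.get(name, 0.0) for name, w in coef.items()}
    logit = intercept + sum(contributions.values())
    return contributions, 1.0 / (1.0 + math.exp(-logit))


# Hypothetical fitted coefficients, for illustration only.
coef = {"Age": 0.03, "EducationNum": 0.3}
contribs, prob = lr_local_explanation({"Age": 40.0, "EducationNum": 13.0}, coef, intercept=-8.0)
print(contribs, round(prob, 4))
```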