EBM Internals - Binary classification

This is part 2 of a 3-part series describing EBM internals and how to make predictions. Part 1 covers regression, and part 3 covers multiclass.

In part 2 we’ll cover binary classification, interactions, missing values, ordinals, and the reduced discretization resolutions for interactions. Before reading this part, you should be familiar with part 1.

# boilerplate
from interpret import show
from interpret.glassbox import ExplainableBoostingClassifier
import numpy as np

from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())
# make a dataset composed of an ordinal categorical, and a continuous feature
X = [["low", 8.0], ["medium", 7.0], ["high", 9.0], [None, None]]
y = ["apples", "apples", "oranges", "oranges"]

# Fit a classification EBM with 1 interaction
# Define an ordinal feature with specified ordering
# Limit the number of interaction bins to force a lower resolution
# Eliminate the validation set to handle the small dataset
ebm = ExplainableBoostingClassifier(
    interactions=1,
    feature_types=[["low", "medium", "high"], 'continuous'], 
    max_interaction_bins=4,
    validation_size=0, outer_bags=1, max_rounds=900, min_samples_leaf=1, min_hessian=1e-9)
ebm.fit(X, y)
show(ebm.explain_global())
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpret/glassbox/_ebm/_ebm.py:738: UserWarning: Missing values detected. Our visualizations do not currently display missing values. To retain the glassbox nature of the model you need to either set the missing values to an extreme value like -1000 that will be visible on the graphs, or manually examine the missing value score in ebm.term_scores_[term_index][0]
  warn(





print(ebm.classes_)
['apples' 'oranges']

Like all scikit-learn classifiers, we store the list of classes in the ebm.classes_ attribute as a sorted array. In this example our classes are strings, but we also accept integers, as we’ll see in part 3.

print(ebm.feature_types)
[['low', 'medium', 'high'], 'continuous']

In this example we passed feature_types into the __init__ function of the ExplainableBoostingClassifier. Per scikit-learn convention, this was recorded unmodified in the ebm object.

print(ebm.feature_types_in_)
['ordinal', 'continuous']

The feature_types passed into __init__ were resolved into the base feature types ['ordinal', 'continuous']. Following the spirit of scikit-learn’s SLEP007 convention, we recorded this in ebm.feature_types_in_.

print(ebm.feature_names_in_)
['feature_0000', 'feature_0001']

Since we did not specify feature names, some default names were created for the model. If we had passed feature_names to the __init__ function of the ExplainableBoostingClassifier, or if we had used a Pandas dataframe with column names, then ebm.feature_names_in_ would have contained those names.

print(ebm.term_features_)
[(0,), (1,), (0, 1)]

Our model contains 3 additive terms. The first two terms are the main effect features, and the 3rd term is the pairwise interaction between the individual features. EBMs are not limited to only main and pair effects. We also support 3-way interactions, 4-way interactions, and higher order interactions as well. If there were any higher order interactions in the model, they would be listed in ebm.term_features_ as further tuples of indexes.
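As a small illustration of how the tuples in ebm.term_features_ can be read, the sketch below labels each term by the number of features it uses. The tuples are copied from the printed output above, and describe_term is a hypothetical helper, not part of the interpret API.

```python
# Values copied from the printed ebm.term_features_ above
term_features = [(0,), (1,), (0, 1)]


def describe_term(features):
    """Hypothetical helper: label a term by how many features it uses."""
    n = len(features)
    if n == 1:
        return "main effect"
    if n == 2:
        return "pairwise interaction"
    return f"{n}-way interaction"


term_kinds = [describe_term(features) for features in term_features]
for features, kind in zip(term_features, term_kinds):
    print(features, "->", kind)
```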

print(ebm.term_names_)
['feature_0000', 'feature_0001', 'feature_0000 & feature_0001']

ebm.term_names_ is a convenience attribute that joins ebm.term_features_ and ebm.feature_names_in_ to create names for each of the additive terms.

ebm.term_names_ is the result of:

term_names = [" & ".join(ebm.feature_names_in_[i] for i in grp) for grp in ebm.term_features_]

print(ebm.bins_)
[[{'low': 1, 'medium': 2, 'high': 3}], [array([7.5, 8.5]), array([8.5])]]

ebm.bins_ is a per-feature attribute. As described in part 1, ebm.bins_ defines how to bin both categorical (‘nominal’ and ‘ordinal’) and ‘continuous’ features.

For categorical features we use a dictionary that maps the category strings to bin indexes.

As described in part 1, continuous feature binning is defined with a list of cut points that partition the continuous range into regions. In this example, our dataset has 3 unique values for the continuous feature: 7.0, 8.0, and 9.0. As in part 1, the main effects in this example have 2 bin cuts, 7.5 and 8.5, that separate these values into 3 regions.

EBMs support the ability to reduce the binning resolution when binning a feature for interactions. In the call to __init__ for the ExplainableBoostingClassifier, we specified max_interaction_bins=4, which limited the EBM to creating just 4 bins when binning for interactions. Two of those bins are reserved for ‘missing’ and ‘unknown’ values, which leaves the model with 2 bins for the remaining continuous feature values. We have 3 unique values in our dataset though, so the EBM is forced to decide which of these values to group together and to choose a single cut point that separates them into the 2 regions. In this example, the EBM could have chosen any cut point between 7.0 and 9.0. It chose 8.5, which puts the 7.0 and 8.0 values in the lower bin and 9.0 in the upper bin.

The binning definitions for main effect and interactions are stored in a list for each feature in the ebm.bins_ attribute. In this example, ebm.bins_[1] contains a list of arrays: [array([7.5, 8.5]), array([8.5])]. The first array of [7.5, 8.5] at ebm.bins_[1][0] is the binning resolution for main effects. The second array of [8.5] at ebm.bins_[1][1] is the binning resolution used when binning for interactions.
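To make the difference between the two resolutions concrete, here is a minimal sketch that bins the three continuous values at both levels. The cut points are copied from the printed ebm.bins_[1] above rather than read from a fitted model, and the variable names are illustrative.

```python
import numpy as np

# Cut points copied from the printed ebm.bins_[1] above
main_cuts = np.array([7.5, 8.5])  # resolution used for the main effect
pair_cuts = np.array([8.5])       # lower resolution used for interactions

values = np.array([7.0, 8.0, 9.0])

# add 1 because bin index 0 is reserved for 'missing'
main_bins = np.digitize(values, main_cuts) + 1
pair_bins = np.digitize(values, pair_cuts) + 1

print(main_bins)  # each value gets its own bin
print(pair_bins)  # 7.0 and 8.0 share the lower bin
```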

The binning resolution does not stop at pairs. If an even lower resolution were desired for triples, a 3rd array of bin cuts would be included in the list. The last item in the list is the binning resolution used for its own interaction order and for all higher orders. If the list had contained just the single binning resolution of [7.5, 8.5], then that resolution would be used for main effects, pairs, triples, and higher order interactions.
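The fallback rule described above can be sketched as a small lookup. This assumes only what the paragraph states: the list position corresponds to the interaction order, and the last entry is reused for every higher order. select_cuts is a hypothetical helper, not part of the interpret API.

```python
def select_cuts(feature_bins, n_term_features):
    """Pick the binning level for a term that involves n_term_features features.

    Level 0 is for main effects, level 1 for pairs, and so on. If the list is
    shorter than the interaction order, the last entry is reused.
    """
    level = min(n_term_features, len(feature_bins)) - 1
    return feature_bins[level]


# Cut points copied from the printed ebm.bins_[1] above
bins_feature_1 = [[7.5, 8.5], [8.5]]

print(select_cuts(bins_feature_1, 1))  # main effect
print(select_cuts(bins_feature_1, 2))  # pair
print(select_cuts(bins_feature_1, 3))  # triple falls back to the last entry
```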

print(ebm.term_scores_[0])
[ 10.0755274  -10.05242207 -10.12968962  10.10658429   0.        ]

ebm.term_scores_[0] is the lookup table for the first feature in this example. Since the first feature is an ordinal categorical, we use the dictionary {'low': 1, 'medium': 2, 'high': 3} to look up which bin to use for each categorical string. If the feature value is NaN, then we use the score at index 0. If the feature value is “low”, we use the score at index 1. If the feature value is “medium”, we use the score at index 2. If the feature value is “high”, we use the score at index 3. If the feature value is anything else, we use the score at index 4.
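The lookup rules for the ordinal feature can be sketched as follows. The dictionary and scores are copied from the printed output above, and categorical_bin_index is a hypothetical helper that sends missing values to index 0 and any unseen category to the last index.

```python
import math

# Values copied from the printed ebm.bins_[0][0] and ebm.term_scores_[0] above
category_to_bin = {"low": 1, "medium": 2, "high": 3}
term_scores_0 = [10.0755274, -10.05242207, -10.12968962, 10.10658429, 0.0]


def categorical_bin_index(value, mapping):
    """Hypothetical helper: map a raw category to its bin index."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return 0  # bin 0 is reserved for missing values
    # the last bin (one past the highest category index) is for unknown values
    return mapping.get(value, len(mapping) + 1)


for value in ["low", "medium", "high", None, "huge"]:
    idx = categorical_bin_index(value, category_to_bin)
    print(value, idx, term_scores_0[idx])
```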

In this example the 0th bin score is non-zero because we included a missing value in the dataset for this feature.

print(ebm.term_scores_[1])
[ 7.23620483 -7.22427162 -7.36755721  7.35562401  0.        ]

ebm.term_scores_[1] is the lookup table for the second feature in this example. Since the second feature is a continuous feature, we use cut points for binning. The 0th bin index is again reserved for missing values, and the last bin index is again reserved for unknown values. In this example, the 0th bin score is non-zero because we included a missing value in the dataset for this feature.

The ebm.bins_[1] attribute contains a list having 2 arrays of cut points. In this case we are binning a main effects feature, so we use the bins at index 0, which is ebm.bins_[1][0].
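As a sketch of the main effect lookup for the continuous feature, the snippet below combines the main effect cut points with the printed scores. The arrays are copied from the output above, and continuous_score is a hypothetical helper that does not handle unknown values.

```python
import numpy as np

# Values copied from the printed ebm.bins_[1][0] and ebm.term_scores_[1] above
main_cuts = np.array([7.5, 8.5])
term_scores_1 = np.array([7.23620483, -7.22427162, -7.36755721, 7.35562401, 0.0])


def continuous_score(value):
    """Hypothetical helper: main effect score for one continuous value."""
    if value is None or np.isnan(value):
        return term_scores_1[0]  # bin 0 is reserved for missing values
    # add 1 because bin 0 is reserved for 'missing'
    return term_scores_1[np.digitize(value, main_cuts) + 1]


for value in [7.0, 8.0, 9.0, None]:
    print(value, continuous_score(value))
```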

print(ebm.term_scores_[2])
[[ 3.28000003 -3.00000002 -2.78000002  0.        ]
 [ 3.00000002 -3.28000003 -2.78000002  0.        ]
 [-2.78000002 -3.28000003  2.95000002  0.        ]
 [-2.78000002 -2.95000002  3.28000003  0.        ]
 [ 0.          0.          0.          0.        ]]

ebm.term_scores_[2] is the lookup table for the pair composed of both features in this example. The features involved in the pair can be found at ebm.term_features_[2]. The pair lookup table is two dimensional, so indexing into it requires two indexes. The first index will be the bin index of the first feature, and the second index will be the bin index of the second feature. Example:

pair_scores = ebm.term_scores_[2]

local_score = pair_scores[(feature_0_index, feature_1_index)]
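For a concrete instance of the two-dimensional lookup, the sketch below computes the pair score for the sample ["high", 9.0]. The table, dictionary, and cut point are copied from the printed output above. Note that the continuous feature is binned with the interaction-level cuts, not the main effect cuts.

```python
import numpy as np

# Values copied from the printed ebm.term_scores_[2] above:
# 5 rows for the ordinal feature, 4 columns for the continuous feature
pair_scores = np.array([
    [3.28000003, -3.00000002, -2.78000002, 0.0],
    [3.00000002, -3.28000003, -2.78000002, 0.0],
    [-2.78000002, -3.28000003, 2.95000002, 0.0],
    [-2.78000002, -2.95000002, 3.28000003, 0.0],
    [0.0, 0.0, 0.0, 0.0],
])
category_to_bin = {"low": 1, "medium": 2, "high": 3}
pair_cuts = np.array([8.5])  # interaction-level cuts from ebm.bins_[1][1]

feature_0_index = category_to_bin["high"]
# add 1 because bin 0 is reserved for 'missing'
feature_1_index = int(np.digitize(9.0, pair_cuts)) + 1

local_score = pair_scores[(feature_0_index, feature_1_index)]
print((feature_0_index, feature_1_index), local_score)
```

A sample with both values missing would index the table at (0, 0).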

Sample code

Finally, here’s some code which puts the above considerations together into a function that can make predictions for simplified scenarios. This code does not handle things like regression, multiclass, unknown values, or interactions beyond pairs.

If you need a drop-in complete function that can work in all EBM scenarios, see the multiclass example in part 3 which handles regression and binary classification in addition to multiclass and all the other nuances.

sample_scores = []
for sample in X:
    # start from the intercept for each sample; for binary classification
    # ebm.intercept_ is a 1-element array, so take its only element
    score = float(ebm.intercept_[0])
    
    # We have 3 terms: two main effects, and 1 pair interaction
    for term_idx, features in enumerate(ebm.term_features_):
        # indexing into a tensor requires a multi-dimensional index
        tensor_index = []

        # main effects will have 1 feature, and pairs will have 2 features
        for feature_idx in features:
            feature_val = sample[feature_idx]
            bin_idx = 0  # if missing value, use bin index 0

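            # np.isnan catches every NaN float, not just the np.nan object itself
            if feature_val is not None and not (isinstance(feature_val, float) and np.isnan(feature_val)):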
                # we bin differently for main effects and pairs,
                # so determine which resolution is needed
                if len(features) == 1 or len(ebm.bins_[feature_idx]) == 1:
                    # this is a main effect or only one bin level
                    # is available, so use the highest resolution bins
                    bins = ebm.bins_[feature_idx][0]
                elif len(features) == 2 or len(ebm.bins_[feature_idx]) == 2:
                    # use the lower resolution bins
                    bins = ebm.bins_[feature_idx][1]
                else:
                    raise Exception("Unsupported bin resolution")

                if isinstance(bins, dict):
                    # categorical feature
                    bin_idx = bins[feature_val]
                else:
                    # continuous feature
                    # add 1 because the 0th bin is reserved for 'missing'
                    bin_idx = np.digitize(feature_val, bins) + 1

            tensor_index.append(bin_idx)
        # local_score is also the local feature importance
        local_score = ebm.term_scores_[term_idx][tuple(tensor_index)]
        score += local_score
    sample_scores.append(score)

logits = np.array(sample_scores)

# use the sigmoid function to convert the logits into probabilities
probabilities = 1 / (1 + np.exp(-logits))

print("probability of " + ebm.classes_[1])
print(ebm.predict_proba(X)[:, 1])
print(probabilities)
probability of oranges
[1.02547269e-09 1.09545709e-09 9.99999999e-01 9.99999999e-01]
[1.02547269e-09 1.09545709e-09 9.99999999e-01 9.99999999e-01]

For regression, our default link function was the identity link function, so the scores were the actual predictions.

For classification, the scores are logits and we need to apply the inverse link function to calculate the probabilities. For binary classification the inverse link function is the sigmoid function.
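As a minimal sketch of the inverse link step, the snippet below converts a few made-up logits into two-column probabilities and class labels. The logits are illustrative values, not outputs of the model above, and the column order (class 0 first, class 1 second) is assumed to follow the scikit-learn predict_proba convention.

```python
import numpy as np


def sigmoid(logits):
    """Inverse link for binary classification: logit -> P(class 1)."""
    return 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))


# Illustrative logits, not taken from the fitted model
logits = np.array([-20.0, 0.0, 20.0])

p_positive = sigmoid(logits)
# column 0 is P(classes[0]), column 1 is P(classes[1])
proba = np.column_stack([1.0 - p_positive, p_positive])

classes = np.array(["apples", "oranges"])
labels = classes[np.argmax(proba, axis=1)]

print(proba)
print(labels)
```

A logit of 0 maps to a probability of exactly 0.5, so the sign of the logit determines which class is more likely.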

As with regression in part 1, the local_score variable contains the values shown for the local explanations.

show(ebm.explain_local(X, y), 0)