# Explaining Multi-class XGBoost Models with SHAP

Posted on May 12, 2019 in posts

These days, when people talk about machine learning, they are usually referring to the modern nonlinear methods that tend to win Kaggle competitions: Random Forests, Gradient Boosted Trees, XGBoost, or the various forms of Neural Networks. People talk about how these modern methods generally provide lower bias and are able to better optimize an objective function than more traditional methods like Linear Regression or Logistic Regression (for classification). However, when organizations, specifically organizations in heavily regulated industries like finance, healthcare, and insurance, talk about machine learning, they tend to talk about how they can't implement machine learning in their business because it's too much of a "black box."

These organizations make underwriting and pricing decisions based on predictions for annual income, credit default risk, probability of death, disease risk, and many others. They worry about a series of regulatory requirements forcing them to explain why a particular decision was reached on a single sample, in a clear and defensible manner. Nobody wants to be the first to test a new regulatory standard, and it is far easier to continue business as usual, so these organizations like their tidy formulas and interpretable coefficients and won't give them up without good reason. From the perspective of a data scientist, that good reason is lower model bias, which leads to better predictions and, in turn, better customer experiences, fewer regulatory issues, and ultimately a stronger competitive advantage and higher profits for the enterprise. The first mover has much to gain, but also a lot to lose. By applying the techniques discussed here, it should become clear that there are ways to create value and effectively mitigate the regulatory risks involved.

In recent years, several methods have emerged that help us crack open the "black box" of nonparametric machine learning models and see what is happening inside. Thanks to ongoing research in the field of ML model explainability, we now have at least five good methods with which we can explore the inner workings of our models.

An exhaustive review of all methods is outside the scope of this article, but below is a non-exhaustive set of links for those interested in further research:

The goals of this post are to:

- Build an XGBoost binary classifier
- Showcase SHAP to explain model predictions so a regulator can understand
- Discuss some edge cases and limitations of SHAP in a multi-class problem

In a well-argued piece, one of the team members behind SHAP explains why SHAP is the ideal choice for explaining ML models and is superior to other methods. SHAP stands for 'SHapley Additive exPlanations', and it applies game theory to local explanations to create consistent and locally accurate additive feature attributions. If this doesn't make a lot of sense, don't worry; the graphs below will mostly speak for themselves.

In this post I will demonstrate a simple XGBoost example for a binary and a multi-class classification problem, and show how to use SHAP to effectively explain what is going on under the hood. I will begin with a binary classifier using the Titanic survival dataset.

Then we will explain the predictions using SHAP plots like this one:

## 1. Build the Model

First we import the necessary packages:

```
import xgboost as xgb
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_validate
import matplotlib.pylab as pl
import pandas as pd
import numpy as np
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
```

Next, we import our data and look at a preview:

```
titanic = pd.read_csv('./extra/data/titanic.csv')
print(titanic.shape)
titanic.head()
```

```
#Check for missing values
print(titanic.isna().sum()[titanic.isna().sum()>0])
```

Our target column is the binary `survived`, and we will use every column except `name`, `ticket`, and `cabin`. First, we will create an engineered feature, `family_size`, by adding together the sibling/spouse (`sibsp`) and parent/child (`parch`) counts. Then we must deal with missing values in the `age` and `embarked` columns, so we impute them (the mean for `age`, the mode for `embarked`). Next, we dummy encode the categorical columns `pclass`, `sex`, and `embarked`. Finally, we can drop the extra columns, assign our X and y, and move on to training our model.
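Before running the pandas code, it may help to see the same steps in miniature. The sketch below is plain Python on a few hypothetical rows (the `rows` data and the `preprocess` helper are made up for illustration; the post itself uses pandas). It also shows where column names like `pclass_3`, `sex_male`, and `embarked_S` come from: with `drop_first=True`, pandas drops the first category in sorted order, and the sketch assumes that same ordering.

```python
from collections import Counter
from statistics import mean

# Hypothetical toy rows standing in for the Titanic CSV; None marks a missing value.
rows = [
    {"pclass": 3, "sex": "male", "age": 22.0, "sibsp": 1, "parch": 0, "embarked": "S"},
    {"pclass": 1, "sex": "female", "age": 38.0, "sibsp": 1, "parch": 0, "embarked": "C"},
    {"pclass": 3, "sex": "female", "age": None, "sibsp": 0, "parch": 0, "embarked": "S"},
    {"pclass": 2, "sex": "male", "age": 35.0, "sibsp": 0, "parch": 2, "embarked": None},
]


def preprocess(rows, categorical=("pclass", "sex", "embarked")):
    """Mirror the pandas steps: family_size, mean/mode imputation, drop_first dummies."""
    rows = [dict(r) for r in rows]  # work on copies so the caller's rows are untouched
    age_mean = mean(r["age"] for r in rows if r["age"] is not None)
    embarked_counts = Counter(r["embarked"] for r in rows if r["embarked"] is not None)
    embarked_mode = embarked_counts.most_common(1)[0][0]
    for r in rows:
        r["family_size"] = r["sibsp"] + r["parch"]
        if r["age"] is None:
            r["age"] = age_mean
        if r["embarked"] is None:
            r["embarked"] = embarked_mode
    for col in categorical:
        levels = sorted({r[col] for r in rows})  # sorted, like pandas' inferred categories
        for r in rows:
            value = r.pop(col)
            for level in levels[1:]:  # drop_first=True: skip the first (lowest) level
                r[f"{col}_{level}"] = int(value == level)
    return rows


processed = preprocess(rows)
print(processed[0])
```

In the real dataset `embarked` has three levels (C, Q, S), so the same rule yields `embarked_Q` and `embarked_S`.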

```
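from IPython.display import display  #display() is built into Jupyter; import it explicitly elsewhere
#keep a reference to the data for later display (an alias, not a copy, so it also
#picks up the family_size column and the imputed values added below)
titanic_orig = titanic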
#Create family size engineered feature
titanic['family_size'] = titanic['sibsp'] + titanic['parch']
#impute missing age with mean age
titanic['age'] = titanic['age'].fillna(titanic['age'].mean())
#impute missing embarked with mode
titanic['embarked'] = titanic['embarked'].fillna(titanic['embarked'].mode().values[0])
#dummy encode categorical columns
cols_to_encode = ['pclass','sex','embarked'] #,'cabin_letter'
titanic = pd.get_dummies(titanic, columns=cols_to_encode, drop_first=True)
#drop extra columns from X
cols_to_drop = ['survived','name','ticket','cabin']
y = titanic.loc[:,'survived']
X = titanic.drop(cols_to_drop, axis='columns')
display(X.head())
print('Survived class distribution: \n', y.value_counts(normalize = True),'\n')
```

Above, we see a sample of our final training dataset and below it the distribution of the `survived` column: only 38% of passengers survived! Because tree-based models are insensitive to feature scaling and fairly robust to multicollinearity, and because this is a very simple dataset, we can skip the usual EDA and data normalization procedures and jump to model training and evaluation. Below we train an XGBoost binary classifier and tune its hyperparameters with a grid search over stratified 3-fold cross-validation.

```
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
xgbcl = xgb.XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1.0,
gamma=0.0, max_delta_step=0.0, min_child_weight=1.0,
missing=None, n_jobs=-1, objective='binary:logistic', random_state=42, reg_alpha=0.0,
reg_lambda=1.0, scale_pos_weight=1.0, tree_method='auto')
kfold = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
param_grid = {
'colsample_bytree':[.75,1],
'learning_rate':[0.01,0.05,0.1,0.3,0.5],
'max_depth':[1,2,3,5],
'subsample':[.75,1],
'n_estimators': list(range(50, 400, 50))
}
grid_search = GridSearchCV(estimator=xgbcl, scoring='roc_auc', param_grid=param_grid, n_jobs=-1, cv=kfold)
grid_result = grid_search.fit(X_train, y_train)
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}','\n')
```
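The grid above is larger than it looks. As a quick sanity check on run time, the plain-Python sketch below reuses the same grid values and counts how many candidate settings `GridSearchCV` evaluates; with its default `refit=True`, it also fits the best candidate once more on all of `X_train`.

```python
from itertools import product

# Same values as param_grid above; GridSearchCV tries every combination on every fold.
param_grid = {
    "colsample_bytree": [0.75, 1],
    "learning_rate": [0.01, 0.05, 0.1, 0.3, 0.5],
    "max_depth": [1, 2, 3, 5],
    "subsample": [0.75, 1],
    "n_estimators": list(range(50, 400, 50)),
}
n_folds = 3

# Expand the grid into one dict per candidate setting.
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
n_fits = len(candidates) * n_folds + 1  # +1 for the final refit on all of X_train
print(len(candidates), "candidates,", n_fits, "model fits")
```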

We have now found the hyperparameters that maximize the area under the ROC curve (ROC AUC) in cross-validation. Next, we will use them to train our final model, but first, because the dataset is so small, we will run a final 5-fold cross-validation to get stable error metrics and confirm a good fit.
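`StratifiedKFold` splits the data so that each fold keeps roughly the same share of survivors as the full training set, which matters when one class makes up only about 38% of the data. The function below is a simplified, hypothetical sketch of that idea (shuffle within each class, then deal indices round-robin into folds); scikit-learn's actual allocation differs in the details, so fold membership will not match it.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_splits=5, seed=42):
    """Assign each index to a fold so every fold keeps roughly the overall class mix.

    Simplified illustration only: shuffle the indices of each class, then deal them
    out round-robin, continuing from wherever the previous class stopped.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    folds = [[] for _ in range(n_splits)]
    offset = 0
    for cls in sorted(by_class):
        idx = by_class[cls]
        rng.shuffle(idx)
        for j, i in enumerate(idx):
            folds[(j + offset) % n_splits].append(i)
        offset += len(idx)
    return folds


# Toy labels with a 38% positive rate, similar to the survival split.
labels = [1] * 38 + [0] * 62
folds = stratified_folds(labels, n_splits=5)
for k, fold in enumerate(folds):
    rate = sum(labels[i] for i in fold) / len(fold)
    print(f"fold {k}: size={len(fold)}, positive rate={rate:.2f}")
```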

```
#Set our final hyperparameters to the tuned values
xgbcl = xgb.XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1.0,
gamma=0.0, max_delta_step=0.0, min_child_weight=1.0,
missing=None, n_jobs=-1, objective='binary:logistic', random_state=42, reg_alpha=0.0,
reg_lambda=1.0, scale_pos_weight=1.0, tree_method='auto',
colsample_bytree = grid_result.best_params_['colsample_bytree'],
learning_rate = grid_result.best_params_['learning_rate'],
max_depth = grid_result.best_params_['max_depth'],
subsample = grid_result.best_params_['subsample'],
n_estimators = grid_result.best_params_['n_estimators'])
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
#refit the model on k-folds to get stable avg error metrics
scores = cross_validate(estimator=xgbcl, X=X_train, y=y_train, cv=kfold, n_jobs=-1,
scoring=['accuracy', 'roc_auc', 'precision', 'recall', 'f1'])
print('Training 5-fold Cross Validation Results:\n')
print('AUC: ', scores['test_roc_auc'].mean())
print('Accuracy: ', scores['test_accuracy'].mean())
print('Precision: ', scores['test_precision'].mean())
print('Recall: ', scores['test_recall'].mean())
print('F1: ', scores['test_f1'].mean(), '\n')
```

Above, we see a good AUC in the high 80s and an accuracy in the 80s, which is far better than the roughly 61% accuracy we would get by guessing `0` every time. Next, we'll fit the final model and plot the ROC curve. We could improve further by deciding whether we care more about false positives or false negatives and tuning our prediction threshold accordingly, but this is good enough to stop and show off SHAP.

```
#Fit the final model
xgbcl.fit(X_train, y_train)
#Generate predictions against our training and test data
pred_train = xgbcl.predict(X_train)
proba_train = xgbcl.predict_proba(X_train)
pred_test = xgbcl.predict(X_test)
proba_test = xgbcl.predict_proba(X_test)
# Print model report
print("Classification report (Test): \n")
print(metrics.classification_report(y_test, pred_test))
print("Confusion matrix (Test): \n")
print(metrics.confusion_matrix(y_test, pred_test)/len(y_test))
print ('\nTrain Accuracy:', metrics.accuracy_score(y_train, pred_train))
print ('Test Accuracy:', metrics.accuracy_score(y_test, pred_test))
print ('\nTrain AUC:', metrics.roc_auc_score(y_train, proba_train[:,1]))
print ('Test AUC:', metrics.roc_auc_score(y_test, proba_test[:,1]))
# calculate the fpr and tpr for all thresholds of the classification
train_fpr, train_tpr, train_threshold = metrics.roc_curve(y_train, proba_train[:,1])
test_fpr, test_tpr, test_threshold = metrics.roc_curve(y_test, proba_test[:,1])
train_roc_auc = metrics.auc(train_fpr, train_tpr)
test_roc_auc = metrics.auc(test_fpr, test_tpr)
fig, ax = plt.subplots(figsize=[7,5])
plt.title('Receiver Operating Characteristic')
plt.plot(train_fpr, train_tpr, 'b', label = 'Train AUC = %0.2f' % train_roc_auc)
plt.plot(test_fpr, test_tpr, 'g', label = 'Test AUC = %0.2f' % test_roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
# plot feature importance
xgb.plot_importance(xgbcl, importance_type='gain');
```

Above, we see the final model is making decent predictions with only minor overfitting. Using XGBoost's built-in feature importance with `importance_type='gain'` (the average loss reduction from the splits that use each feature), we see which attributes most reduced the loss function on the training dataset. In this case `sex_male` was by far the most important feature, followed by `pclass_3`, which indicates a 3rd class ticket. We know from historical accounts that there were not enough lifeboats for everyone and two groups were prioritized: first class passengers and women with children. So `sex` and `pclass` are justifiably important, but this method provides precious little to explain precisely why a prediction was made on a case-by-case basis.
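Earlier we mentioned tuning the prediction threshold instead of relying on the default 0.5. One simple way to do that, sketched below under assumptions of our own, is to assign a cost to each false positive and false negative and scan candidate thresholds for the lowest total cost. The labels, probabilities, and costs here are made up; with the real model you would pass labels and `[:,1]` probabilities, ideally from held-out validation data rather than the test set.

```python
def confusion_counts(y_true, proba, threshold):
    """Return (tp, fp, fn) when predicting 1 for every probability >= threshold."""
    tp = fp = fn = 0
    for y, p in zip(y_true, proba):
        pred = 1 if p >= threshold else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1 and y == 0:
            fp += 1
        elif pred == 0 and y == 1:
            fn += 1
    return tp, fp, fn


def best_threshold(y_true, proba, fp_cost=1.0, fn_cost=1.0):
    """Scan thresholds 0.01..0.99 and return (threshold, cost) with the lowest total cost."""
    best = None
    for step in range(1, 100):
        t = step / 100
        _, fp, fn = confusion_counts(y_true, proba, t)
        cost = fp * fp_cost + fn * fn_cost
        if best is None or cost < best[1]:  # keep the first threshold reaching the minimum
            best = (t, cost)
    return best


# Made-up labels and predicted survival probabilities, for illustration only.
y_true = [0, 0, 0, 1, 0, 1, 0, 1, 1, 1]
proba = [0.10, 0.20, 0.35, 0.40, 0.45, 0.55, 0.60, 0.70, 0.80, 0.90]
print(best_threshold(y_true, proba))               # equal costs
print(best_threshold(y_true, proba, fp_cost=5.0))  # a false "survived" is 5x worse
```

Making false positives more expensive pushes the chosen threshold above 0.5, trading missed survivors for fewer wrong "survived" calls.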

Now that we have a trained model, let us make a prediction on a random row of data, and then use SHAP to understand why this was predicted.

## 2. Use SHAP to Explain the Model

```
#take a random row of data
X_rand = X.sample(1, random_state = 5)
display(titanic_orig.loc[X_rand.index])  #X_rand.index holds row labels, so use .loc
#generate a prediction for the row
rand_pred = xgbcl.predict(X_rand)
rand_proba = xgbcl.predict_proba(X_rand)
print(f'Predicted Probability (no, yes): {rand_proba}')
print(f'Predicted Survival (0 = no, 1 = yes): {rand_pred}','\n')
```

We see the input data of row 126 from the dataset, belonging to a 29-year-old male, Mr. Martin McMahon, holding a 3rd class ticket, and the output prediction was `0` with an 87% probability. Because we are using the default threshold of 50%, 87% is more than enough to trigger a prediction of `0`. But **why** did the model assign an 87% probability of death and only a 13% probability of survival to this particular passenger? This is the question a regulator would want answered if this passenger had survived and complained to the authority that he was very much alive and took great offense at our inaccurate prediction. In this case the model correctly predicted his unfortunate end, but even when we are right we still need to understand why. To gain this understanding, we will import the SHAP package and explain this row of data.

Before we do, it's worth mentioning how SHAP actually works. Greatly oversimplifying, SHAP starts from a base value (here, the model's average predicted probability of survival over the background data, close to the 0.38 survival rate for anyone aboard) and, for each row, measures how each feature's actual value moves the prediction away from that base value compared with values drawn from the rest of the dataset. For non-linear models the order in which the features are added matters, so SHAP values are obtained by averaging each feature's contribution across all possible orderings. In so doing, SHAP is essentially building a mini explainer model for a single row-prediction pair to explain how this prediction was reached. The full source text is available here. Now, let's have a look at SHAP.
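To make the "average over all orderings" idea concrete, here is a brute-force sketch that computes exact Shapley values for one row of a tiny hypothetical model (`toy_model` and its data are made up). The value of a set of features is taken as the model's average output when those features come from the row being explained and the rest come from background rows, which mirrors the interventional setup used with `TreeExplainer` below. This is only an illustration: its cost grows factorially with the number of features, whereas `TreeExplainer` uses a tree-specific algorithm that avoids enumerating orderings.

```python
from itertools import permutations
from math import factorial
from statistics import mean


def shapley_values(model, x, background):
    """Exact Shapley values for row x, averaging marginal contributions over all orderings.

    v(S) = average over background rows b of model(features in S from x, the rest from b).
    Returns (base_value, phi) where base_value = v(empty set).
    """
    n = len(x)
    cache = {}

    def value(coalition):
        key = frozenset(coalition)
        if key not in cache:
            cache[key] = mean(
                model([x[i] if i in key else b[i] for i in range(n)]) for b in background
            )
        return cache[key]

    phi = [0.0] * n
    for order in permutations(range(n)):
        coalition = set()
        for i in order:
            before = value(coalition)
            coalition.add(i)
            phi[i] += value(coalition) - before  # marginal contribution of feature i
    n_orders = factorial(n)
    return value(set()), [p / n_orders for p in phi]


def toy_model(row):
    # Hypothetical nonlinear "model": x0 and x2 interact; feature 3 is never used.
    return 0.5 * row[0] + 0.3 * row[1] + 0.2 * row[0] * row[2]


background = [[0, 0, 0, 0], [1, 1, 0, 1], [0, 1, 1, 0]]
x = [1, 0, 1, 1]
base, phi = shapley_values(toy_model, x, background)
print("base value:", base)
print("SHAP values:", phi)
print("base + sum:", base + sum(phi), "model(x):", toy_model(x))
```

The last line shows the same additivity property that this post checks with an `assert` later on: the base value plus the SHAP values reproduces the model's output for the row, and feature 3, which the toy model ignores, gets exactly zero.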

```
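import shap
#Explain predicted probabilities, using X as background data. The 2019 shap API
#called this option feature_dependence='independent'; newer releases name it
#feature_perturbation='interventional'.
explainer = shap.TreeExplainer(xgbcl, data=X, model_output='probability', feature_perturbation='interventional')
shap_values = explainer.shap_values(X_rand)
shap.initjs()
display(shap.force_plot(explainer.expected_value, shap_values, X_rand))
#Display all features and SHAP values, most negative first
row_id = X_rand.index[0]
display(pd.DataFrame(data=shap_values, columns=X.columns, index=[row_id]).transpose().sort_values(by=row_id, ascending=True))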
```

Seeing a SHAP plot is like seeing the magician behind the green curtain in the Wizard of Oz. Its beauty lies in how the feature effects are additive (adding up to the total predicted probability for the `1` class) and localized to a single prediction. The table below the plot is the full list of features and their SHAP additive weights. We can see very clearly that the model brought down his probability of survival by 16 percentage points because `sex_male == 1`, and by an additional 5 percentage points because `pclass_3 == 1`. On the other hand, in his case `family_size == 0` slightly helped his odds, along with `embarked_S == 0`. When deciding whether an input attribute value helped or hurt his chances, SHAP follows an *all else equal* logic, much as you would interpret a coefficient (m) in a parametric model (y = mx + b). Furthermore, we can empirically show that the additive nature of SHAP holds true: `base value (the explainer's expected value, roughly the 0.38 survival rate) + SUM(SHAP values) == Predicted Probability of Survival`.

```
#Check that the base value plus the sum of SHAP values equals our prediction for class 1
import math
shap_sum = explainer.expected_value + shap_values.sum()
assert math.isclose(shap_sum, rand_proba[0][1], rel_tol=1e-06)  #raises an AssertionError if they differ
```

Now we can use SHAP to view how the features affected the predicted probabilities across the whole dataset.

```
shap_values = explainer.shap_values(X)
shap.summary_plot(shap_values, X)
```

Above, we see being male is generally a bad thing, but the horizontal dispersion also implies that it depends on other factors. If it did not, we would see a single blob of red on the `sex_male` line instead of points spread across the X-axis with varying negative SHAP values. At a glance we also see that high values of `fare` tended to aid survival probability, while low values of `age` greatly helped the survival odds. By looking at the SHAP dependence plots we can better understand the interdependence of the features.

```
shap.dependence_plot('sex_male', shap_values, X, interaction_index='pclass_3')
shap.dependence_plot('fare', shap_values, X) #when we don't specify an interaction_index, the strongest one is automatically chosen for us
shap.dependence_plot('age', shap_values, X, interaction_index='sex_male')
```

With SHAP dependence plots we can see how `sex_male` influences the prediction and how, in turn, its effect is influenced by `pclass_3`. We see a clear survival benefit to being a woman; being in 3rd class hurt a woman's odds but had a lesser effect on a man's (because his survival odds were already so bad). The model believes it was better to be a woman in 3rd class than almost any man. The second plot illustrates that a higher `fare` generally conferred a survival benefit, likely due to its link to cabin class and therefore proximity to the lifeboats. Finally, `age` is interesting because we see a clear benefit to being a child under the age of 10 through an increase in the probability of survival, but then we see an interesting spike in the 25-35 range. Further analysis would be warranted, but this could reflect the most common ages of the parents who were prioritized alongside their children. We also see more evidence that being a woman at almost any age was better than being a man in terms of survivability.

If a regulator were to ask why a decision was made, SHAP can be used to demonstrate exactly which factors added up to the final decision and how they interacted with each other, even in a complex gradient boosted tree ensemble.
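The same per-row explanations can also be rolled up into a global view. A common summary, and the one behind `shap.summary_plot(shap_values, X, plot_type='bar')`, is each feature's mean absolute SHAP value. The stdlib sketch below computes that ranking from a small matrix of made-up SHAP values; with the real model you would pass the full `shap_values` array and `X.columns`.

```python
def mean_abs_shap(shap_matrix, feature_names):
    """Rank features by mean |SHAP value| across rows (a global importance summary)."""
    totals = [0.0] * len(feature_names)
    for row in shap_matrix:
        for j, value in enumerate(row):
            totals[j] += abs(value)
    n_rows = len(shap_matrix)
    ranking = [(name, total / n_rows) for name, total in zip(feature_names, totals)]
    return sorted(ranking, key=lambda pair: pair[1], reverse=True)


# Made-up SHAP values (probability units) for three passengers, for illustration only.
feature_names = ["sex_male", "pclass_3", "fare", "age"]
shap_matrix = [
    [-0.16, -0.05, 0.01, 0.02],
    [0.20, -0.08, 0.04, -0.01],
    [-0.15, 0.03, -0.02, 0.05],
]
for name, importance in mean_abs_shap(shap_matrix, feature_names):
    print(f"{name:>10}: {importance:.3f}")
```

Unlike XGBoost's gain importance, this ranking is built from the same additive, per-passenger attributions a regulator would see for an individual decision.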

## 3. Multi-class SHAP Example

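So now, let us move to a multi-class example. In this case it's a bit more complex, because SHAP has certain multi-class limitations: at least in the shap version used here, `TreeExplainer` cannot explain this multi-class model in probability space, so we fall back to explaining the raw margin output for each class. I'll keep the model building short so we can focus on the differences from binary classification with SHAP.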

```
#For the multi-class example we use iris dataset
X,y = shap.datasets.iris()
X_display,y_display = shap.datasets.iris(display=True) # human readable feature values
# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
xgbc = xgb.XGBClassifier(learning_rate=0.5,
n_estimators=150,
max_depth=6,
min_child_weight=0,
gamma=0,
reg_lambda=1,
subsample=1,
colsample_bytree=0.75,
scale_pos_weight=1,
objective='multi:softprob',
num_class=3,
eval_metric='mlogloss',  #recent XGBoost versions expect eval_metric here rather than in fit()
random_state=42)
mcl = xgbc.fit(X_train, y_train)
pred = mcl.predict(X_test)
proba = mcl.predict_proba(X_test)
#store our flower labels for results
y_map = pd.DataFrame(data=y, columns=['class'])
y_map['label'] = y_map['class'].map({0:'Iris-setosa',1:'Iris-versicolor',2:'Iris-virginica'})
```

```
#Attempt to use SHAP on multi-class
X_rand = X.sample(1, random_state=42)
idx = X_rand.index.values[0]
#This line will not work for a multi-class model, so we comment it out
#explainer = shap.TreeExplainer(mcl, model_output='probability', feature_dependence='independent', data=X)
explainer = shap.TreeExplainer(mcl)
shap_values = explainer.shap_values(X_rand)
#Older shap versions return a list with one (1, n_features) array per class;
#newer versions return a single (1, n_features, n_classes) array. Normalize to the latter.
if isinstance(shap_values, list):
    shap_values = np.stack(shap_values, axis=-1)
shap.initjs()
for which_class in range(0,3):
    display(shap.force_plot(explainer.expected_value[which_class], shap_values[:, :, which_class], X_rand))
#Display all features and SHAP values (one column per class)
df = pd.DataFrame(data=shap_values[0].T, columns=X.columns, index=[0,1,2])
display(df.transpose())
```
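Because this explainer works on XGBoost's raw output rather than probabilities, each class's expected value plus its SHAP values adds up to that class's margin (a pre-softmax score), not to the probability returned by `predict_proba`. For `multi:softprob`, the class probabilities are the softmax of those margins, so the sketch below rebuilds probabilities from per-class explanations. The numbers are made up for illustration; with the real model you would pass `explainer.expected_value` and the per-class SHAP values for the row.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def class_probabilities(expected_values, shap_rows):
    """Rebuild per-class margins and probabilities from per-class SHAP explanations."""
    margins = [base + sum(row) for base, row in zip(expected_values, shap_rows)]
    return margins, softmax(margins)


# Made-up numbers for illustration only (3 classes x 4 iris features).
expected_values = [0.1, 0.2, -0.1]
shap_rows = [
    [-0.3, 0.1, -1.5, -1.2],
    [0.2, -0.1, 0.9, 0.4],
    [0.1, 0.0, 0.6, 0.7],
]
margins, probs = class_probabilities(expected_values, shap_rows)
print("margins:", margins)
print("probabilities:", probs)
```

A consequence is that a feature's effect on one class's probability also depends on the other classes' margins, which is part of what makes multi-class explanations harder to read than the binary case.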