Intro to SHAP, LIME and Model Interpretability

Machine learning interpretability refers to the ability to understand and explain the predictions and decisions made by machine learning models. It plays a crucial role in building trust and confidence in these models, especially when they are used in high-stakes applications such as healthcare, finance, or other sensitive domains. LIME and SHAP are two such interpretability methods.

Interpretability methods aim to shed light on the “black box” nature of complex machine learning algorithms by providing insights into how they arrive at their predictions. These methods can be broadly categorised into two types:

  • Global interpretability, which explains the model’s behaviour as a whole, and
  • Local interpretability, which explains individual predictions.

We are going to see how we can interpret machine learning models using two of the most widely used methods: LIME and SHAP.


Flow of the Notebook

This notebook is divided into two parts. In the first part we are going to interpret a machine learning model using LIME, and in the second part we will take a similar approach with another library called SHAP.

Although we are going to perform the same operation with both of these libraries, the methods by which they work are drastically different, so it is interesting to see both 😀

Without further ado, let’s get started.

Loading data, training and testing model

First things first, we need data and a model trained on it so that we can interpret its results. For this notebook we are going to use the Body Performance dataset from Kaggle, and we will perform a classification task using a model that is not human-interpretable on its own: a random forest!

Importing important Libraries

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

Importing data directly from Kaggle

The following cell downloads the data from Kaggle directly into the Colab instance. You will need to add your Kaggle API key to the Colab environment to run this cell successfully.

You can check out the tutorial here, which walks you through this.
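If you have not configured the Kaggle API key yet, a minimal Colab setup could look like the sketch below. It assumes you have downloaded a kaggle.json token from your Kaggle account page; the upload helper is specific to Colab.

# upload kaggle.json when prompted (Colab-specific helper)
from google.colab import files
files.upload()

# place the token where the kaggle CLI expects it and restrict its permissions
! mkdir -p ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json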

# download and unzip the dataset with the kaggle CLI
! kaggle datasets download -d kukuroo3/body-performance-data
! unzip /content/body-performance-data.zip

Get data

The dataset contains a single CSV file; point path_to_data at it once the archive is unzipped.

df_train = pd.read_csv(path_to_data)

About Body performance Data

As the name suggests, the dataset contains various measurements for each person, such as age, gender, and height, along with several exercise-performance metrics. Based on these parameters we have to classify every individual into one of four classes: A, B, C, or D.

Columns

The details of every column are given here,

  • age : 20 ~ 64
  • gender : F, M
  • height_cm (to convert to feet, divide by 30.48)
  • weight_kg
  • body fat_%
  • diastolic : diastolic blood pressure (min)
  • systolic : systolic blood pressure (min)
  • gripForce
  • sit and bend forward_cm
  • sit-ups counts
  • broad jump_cm
  • class : A, B, C, D (A: best) / stratified

The ‘class’ column is our target column. The following command gives us a peek at the datatypes in our dataset.

df_train.dtypes

This gives the following output,

age                        float64
gender                      object
height_cm                  float64
weight_kg                  float64
body fat_%                 float64
diastolic                  float64
systolic                   float64
gripForce                  float64
sit and bend forward_cm    float64
sit-ups counts             float64
broad jump_cm              float64
class                       object
dtype: object

As you can see, we have two columns with the object datatype. We need to encode those as numeric categories.

Let’s look at the data itself,

df_train.head(10)
age	gender	height_cm	weight_kg	body fat_%	diastolic	systolic	gripForce	sit and bend forward_cm	sit-ups counts	broad jump_cm	class
0	27.0	M	172.3	75.24	21.3	80.0	130.0	54.9	18.4	60.0	217.0	C
1	25.0	M	165.0	55.80	15.7	77.0	126.0	36.4	16.3	53.0	229.0	A
2	31.0	M	179.6	78.00	20.1	92.0	152.0	44.8	12.0	49.0	181.0	C
3	32.0	M	174.5	71.10	18.4	76.0	147.0	41.4	15.2	53.0	219.0	B
4	28.0	M	173.8	67.70	17.1	70.0	127.0	43.5	27.1	45.0	217.0	B
5	36.0	F	165.4	55.40	22.0	64.0	119.0	23.8	21.0	27.0	153.0	B
6	42.0	F	164.5	63.70	32.2	72.0	135.0	22.7	0.8	18.0	146.0	D
7	33.0	M	174.9	77.20	36.9	84.0	137.0	45.9	12.3	42.0	234.0	B
8	54.0	M	166.8	67.50	27.6	85.0	165.0	40.4	18.6	34.0	148.0	C
9	28.0	M	185.0	84.60	14.4	81.0	156.0	57.9	12.1	55.0	213.0	B

Convert to categorical

The following is one of the ways to map such columns to numeric categories.

genders = {
    'M' : 0,
    'F' : 1
}

class_label = {
    'A' : 0,
    'B' : 1, 
    'C' : 2,
    'D' : 3
}

df_train['gender'] = df_train['gender'].map(genders)
df_train['class'] = df_train['class'].map(class_label)

Parameters and Target

# Separate parameters from target
X = df_train.drop('class', axis=1)
y = df_train['class']

# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Training and Testing the model for Interpretability

Now that we have the data set up, let’s train our model on the training data. This is a simple task.

Note: Our focus in this notebook is interpretability and not model performance, hence we are skipping all the pre-processing steps for this data.

clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

Finding the model performance

We can find the accuracy of the model in the following way,

pred = clf.predict(X_test)

print(accuracy_score(y_test, pred))

# 0.7450541246733856

For an untuned model, that’s pretty good accuracy.

Black box and interpretability

The accuracy score of a random forest model alone does not provide sufficient information for interpreting the model. While accuracy measures the overall correctness of predictions, it does not reveal the underlying factors or decision-making process of the model. Random forest models are ensemble methods that combine multiple decision trees, making them inherently complex and difficult to interpret based solely on accuracy.

That’s why we need some tools which can help us interpret the model better. And hence we are going to learn about LIME and SHAP!

LIME for ML interpretability

LIME, which stands for Local Interpretable Model-agnostic Explanations, is a popular technique in machine learning interpretability. It provides local explanations for individual predictions made by complex black-box models.

LIME creates a simplified, interpretable model around a specific prediction by perturbing the input data and observing the changes in the model’s output.

We will see how we can implement LIME interpretation for our classifier.

Install LIME

Installing LIME is similar to installing any other Python package. The following cell will install LIME on this instance.

! pip install lime

Import Lime

We can import the package directly with import lime, but here we will mainly need two things from it.

  • lime.lime_tabular is a module that enables local interpretability for tabular data; its LimeTabularExplainer class generates interpretable explanations for individual predictions.
  • explain_instance is a method of the explainer that generates an explanation for a single prediction. It takes one input instance and fits a locally interpretable model that approximates the behaviour of the original black-box model around that instance.

We will use these modules when we need them.

import lime
from lime import lime_tabular

Build an interpreter

  • training_data: It expects a numpy array containing your training data, which is why we pass np.array(X_train); these are the input features used to train your machine learning model.
  • feature_names: This parameter should contain the names of the features in your X_train dataset. It helps in providing meaningful feature labels in the generated explanations.
  • mode: It specifies the mode of the explainer, which, in this case, is set to ‘classification’. This indicates that the explainer will be used to explain predictions made by a classification model.

By instantiating LimeTabularExplainer with these parameters, we create an explainer object (interpreter) that can be used to generate local explanations for individual predictions on tabular data using the LIME method.

interpreter = lime_tabular.LimeTabularExplainer(
    training_data=np.array(X_train),
    feature_names=X_train.columns,
    mode='classification'
)

Lime.explain_instance

This function plays a crucial role in understanding the factors influencing specific predictions and providing insights into the decision-making process of the model.

We are using the explain_instance method from the LIME library to generate an explanation for a specific data point using our trained LIME explainer (interpreter).

The predict_fn parameter specifies the function we use to make predictions on the data, in this case clf.predict_proba, which predicts the probabilities for each class.

After generating the explanation exp, we use the show_in_notebook method to display the explanation in a notebook.

This helps us visualise and understand the factors that contributed to the prediction made by our model for that specific data point.

exp = interpreter.explain_instance(
    data_row=X_test.iloc[100],  # the new data point to explain
    predict_fn=clf.predict_proba
)

exp.show_in_notebook(show_table=True)

Important features which affect the classification are highlighted here!

(Figure: LIME explanation output for the selected test instance)
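If you prefer to work with the explanation programmatically instead of the notebook widget, the explanation object also exposes the feature weights directly. Here is a small sketch; the label index 1 is just an illustrative choice for this multi-class problem.

# (feature, weight) pairs for one class label
print(exp.as_list(label=1))

# the same explanation rendered as a matplotlib figure
fig = exp.as_pyplot_figure(label=1)
plt.show()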

Inner workings of LIME for Interpretability

A complete understanding of LIME is beyond the scope of this notebook. LIME is an algorithm proposed in the paper “Why Should I Trust You?”: Explaining the Predictions of Any Classifier.

You can check out the paper here.

Stepwise intro to LIME (a toy code sketch follows the list):

  1. Select an Instance: Choose a specific data instance or input for which you want to interpret the prediction made by your machine learning model.
  2. Perturb the Instance: Create slightly modified copies of the chosen instance by randomly perturbing its feature values. This produces a new set of synthetic instances in the neighbourhood of the original.
  3. Model Prediction: Pass these synthetic instances through the machine learning model and observe the predicted outputs for each instance.
  4. Weighing Instances: Assign weights to the synthetic instances based on their proximity to the original instance. Instances closer to the original instance are given higher weights, while those farther away receive lower weights.
  5. Train an Explainer Model: Build an “explainer” model, such as a linear regression model, using the synthetic instances and their corresponding model predictions as the training data. The weights assigned in the previous step are used to weigh the importance of each synthetic instance during training.
  6. Interpretation: Analyze the coefficients of the explainer model to understand the contribution of each feature in the original instance’s prediction. Positive or negative coefficients indicate whether a feature increases or decreases the prediction, respectively.
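To make these steps concrete, here is a toy sketch of the core LIME loop for a single tabular instance. It is a deliberate simplification of what lime_tabular does internally (the real implementation also discretises features, scales distances, and selects features), so treat it as an illustration rather than the library’s actual algorithm.

from sklearn.linear_model import Ridge

def toy_lime(instance, predict_fn, X_train, class_idx=0,
             n_samples=1000, kernel_width=0.75):
    # steps 1-2: perturb the instance with Gaussian noise scaled by the training std
    std = X_train.std(axis=0).values
    noise = np.random.normal(size=(n_samples, len(instance)))
    Z = instance.values + noise * std

    # step 3: query the black-box model on the synthetic neighbourhood
    preds = predict_fn(Z)[:, class_idx]

    # step 4: weigh samples by proximity to the original instance (RBF kernel)
    dist = np.sqrt((noise ** 2).sum(axis=1))
    weights = np.exp(-(dist ** 2) / (kernel_width * len(instance)))

    # step 5: fit a weighted linear surrogate model
    surrogate = Ridge(alpha=1.0)
    surrogate.fit(Z, preds, sample_weight=weights)

    # step 6: the coefficients are the local feature attributions
    return dict(zip(X_train.columns, surrogate.coef_))

# local attributions for the same test row we explained above (class index 1)
print(toy_lime(X_test.iloc[100], clf.predict_proba, X_train, class_idx=1))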

SHAP for ML interpretability

SHAP (SHapley Additive exPlanations) is a unified framework for interpreting the predictions of machine learning models. It is based on the concept of Shapley values from cooperative game theory.

SHAP assigns a value to each feature in a prediction, measuring its contribution to the prediction’s outcome. It considers all possible combinations of features and calculates their individual importance. By averaging the contributions across all possible combinations, SHAP provides a comprehensive and fair assessment of feature importance.

We will get a better idea of SHAP at the end, when we compare it with LIME.

Install SHAP

It is as simple as installing any other Python package.

! pip install shap

Import SHAP

Here are the modules and functions that we will need for our interpretation,

  • shap.TreeExplainer: This is used to create an explainer object specifically for tree-based models. The TreeExplainer takes a trained tree-based model (such as our random forest) as input and prepares it for interpretation using SHAP values.
  • shap.initjs: This function is used to initialise the JavaScript library required to visualize SHAP plots in Jupyter Notebook.
  • shap.force_plot: This function is used to create a force plot visualisation. A force plot illustrates the individual contributions of features towards a specific prediction made by a machine learning model.
  • shap.summary_plot: This function generates a summary plot that provides an overview of feature importance across multiple predictions. It combines the SHAP values for each feature across a dataset to create a visual representation of the overall impact of features on model predictions.

You can check out the other modules in the official documentation here.

import shap 

shap.TreeExplainer

Just as we built an interpreter for LIME, for SHAP we build a TreeExplainer around our trained model.

explainer = shap.TreeExplainer(clf) # passing our classifier for interpretation

Explainability

Once we have our explainer ready, we can feed it some test cases and see the results using SHAP.

The following cell computes the SHAP values we will use to create the explanation plots.

start_index = 1
end_index = 3
shap_values = explainer.shap_values(X_test[start_index:end_index]) # we can pass multiple test cases at once

Visualise the interpretation : Local Features

As we did with LIME, we can visualise the interpretations with SHAP. The plot is slightly different from LIME’s, but the core idea remains the same: the features with higher Shapley values contribute more to the output.

# Visualise local predictions
shap.initjs()

# Force plot (explaining class index 1)
prediction = clf.predict(X_test[start_index:end_index])[0]
print(f"The RF predicted: {prediction}")
shap.force_plot(explainer.expected_value[1],  # base value for class 1
                shap_values[1],               # SHAP values for class 1
                X_test[start_index:end_index])
(Figure: SHAP force plot for the selected test cases)

Visualise the global features

We can plot the global interpretation of the Shapley values in the following way.

# Feature summary (computed over the whole test set for a global view)
shap_values_all = explainer.shap_values(X_test)
shap.summary_plot(shap_values_all, X_test)
(Figure: SHAP summary plot of feature importance)

As you can see, this is similar to what we did with LIME, but in addition to the local interpretations we also get a global view that lets us compare feature contributions across the whole test set.

Inner workings of SHAP for Interpretability

Although it is not as intuitive as LIME, SHAP is a very powerful algorithm for machine learning interpretation. To read more about it, and about machine learning interpretability in general, you can check out this book.

Here’s a simple explanation of SHAP (a toy brute-force sketch follows the list),

  1. Understanding Contributions: The SHAP algorithm aims to understand the contributions of individual features in making predictions with a machine learning model.
  2. All Possible Combinations: It considers all possible combinations of features and analyses their impact on predictions. Think of it as examining different ways puzzle pieces fit together.
  3. Calculating Importance: The algorithm assigns a value to each feature, indicating its importance or contribution to the final prediction. It quantifies how much each puzzle piece matters.
  4. Averaging Contributions: By averaging these contributions across all possible combinations, SHAP provides a fair assessment of the overall importance of each feature. It considers all the different ways the puzzle pieces can come together.
  5. Unveiling Feature Importance: The algorithm helps us understand which features have the greatest influence on the model’s predictions. It reveals the puzzle pieces that are most crucial for completing the picture.
  6. Interpreting the “Black Box”: SHAP algorithm helps us interpret machine learning models, which often act as “black boxes.” It sheds light on the reasoning behind their decisions, making it easier to trust and understand the models.
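As a toy illustration of the averaging idea above, the exact Shapley value of a single feature can be computed by brute force over all subsets of the remaining features. In this sketch a feature that is 'absent' from a coalition is approximated by its training-set mean; shap’s TreeExplainer uses a far more efficient tree-specific algorithm instead of this exponential enumeration.

from itertools import combinations
from math import factorial

def brute_force_shapley(model, x, background, feature_idx, class_idx=1):
    # Exact Shapley value of one feature for one instance; features outside a
    # coalition are replaced by their background (training-set) mean.
    n = len(x)
    others = [i for i in range(n) if i != feature_idx]
    baseline = background.mean(axis=0)
    phi = 0.0
    for size in range(len(others) + 1):
        for subset in combinations(others, size):
            # prediction for the coalition without the feature of interest
            z_without = baseline.copy()
            z_without[list(subset)] = x[list(subset)]
            # prediction for the same coalition with the feature added
            z_with = z_without.copy()
            z_with[feature_idx] = x[feature_idx]
            gain = (model.predict_proba([z_with])[0, class_idx]
                    - model.predict_proba([z_without])[0, class_idx])
            # Shapley weight for a coalition of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            phi += weight * gain
    return phi

# contribution of 'gripForce' to the class-1 probability of one test row
# (slow: 2**10 coalitions for our 11 features)
idx = list(X_train.columns).index('gripForce')
print(brute_force_shapley(clf, X_test.iloc[1].values, X_train.values, idx))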

LIME vs SHAP for Interpretability

Here are five key points comparing Lime and SHAP:

  1. Interpretation Approach: Lime and SHAP take different approaches to interpretation. Lime focuses on generating local explanations by perturbing the input data and building local surrogate models, while SHAP provides both local and global explanations by leveraging Shapley values and cooperative game theory concepts.
  2. Model Compatibility: Lime is model-agnostic, meaning it can be applied to any machine learning model, regardless of its type or architecture. On the other hand, SHAP has specific support for tree-based models through the TreeExplainer, but it can also be applied to other models.
  3. Feature Importance: Lime provides feature importance explanations on a per-instance basis, highlighting the features that were most influential in the prediction for a specific data point. SHAP, on the other hand, provides both local and global feature importance, allowing you to understand the overall impact of features across the entire dataset.
  4. Explanatory Visualizations: Lime typically generates visualizations like feature importance bar charts or textual explanations, focusing on the specific instance being explained. SHAP, in addition to per-instance visualizations, offers various plots like force plots and summary plots that provide comprehensive insights into feature contributions at both local and global levels.
  5. Mathematical Foundation: Lime employs simple heuristics and uses weighted linear regression as an explainer model. SHAP, on the other hand, is grounded in game theory and relies on Shapley values, which have solid mathematical foundations for fair and consistent feature attribution across instances.

In summary, while both Lime and SHAP provide interpretability for machine learning models, they differ in their approach, model compatibility, the scope of feature importance, visualizations, and underlying mathematical foundations. The choice between Lime and SHAP depends on your specific needs and the characteristics of the models you are working with.

Hypothesis Tests vs Model interpretability

Some of you might be wondering: why not use hypothesis tests for machine learning interpretation? P-values, t-tests and other hypothesis tests also give us an idea of which parameters contribute more towards the output. So why do we need tools like LIME or SHAP?

Importance of Model-agnostic tools for Interpretability

Lime, SHAP, and hypothesis tests like p-value tests and t-tests serve different purposes when it comes to interpreting machine learning models. Here are some key points of comparison:

Lime and SHAP:

  1. Model Interpretability: Lime and SHAP focus on providing interpretability for individual predictions by explaining the contribution of features. They help understand the decision-making process of complex models.

Hypothesis Tests (p-value tests, t-tests):

  1. Statistical Significance: Hypothesis tests assess the statistical significance of relationships or differences between variables. They help determine if an observed effect is likely due to chance or if it represents a meaningful relationship.
  2. Population vs. Individual Instances: Hypothesis tests analyse data at the population level and aim to make inferences about the overall population. Lime and SHAP, on the other hand, provide insights into individual instances or predictions.
  3. Model-agnostic vs. Model-dependent: Lime and SHAP are model-agnostic, meaning they can be applied to any model. Hypothesis tests, such as t-tests, are often model-dependent and assume specific distributional assumptions.

In summary, Lime and SHAP are useful for understanding individual predictions and feature importance in machine learning models, while hypothesis tests like p-value tests and t-tests are employed for assessing statistical significance and making inferences about relationships or differences between variables at the population level. These techniques serve different purposes and can be complementary in providing a comprehensive understanding of the data and the model’s behaviour.


Thanks for reading. Check out my GitHub for other projects.
