Source code for evalml.model_understanding.force_plots

"""Force plots."""
import numpy as np
import shap
from shap import initjs

from evalml.model_understanding.prediction_explanations import (
    explain_predictions,
)
from evalml.utils import jupyter_check


def graph_force_plot(pipeline, rows_to_explain, training_data, y, matplotlib=False):
    """Function to generate force plots for the desired rows of the training data.

    Args:
        pipeline (PipelineBase): The pipeline to generate the force plot for.
        rows_to_explain (list[int]): A list of the indices indicating which of the rows of
            the training_data to explain. Each index is also used positionally (``iloc``)
            to pick the feature values displayed in that row's plot.
        training_data (pandas.DataFrame): The data used to train the pipeline.
        y (pandas.Series): The target data for the pipeline.
        matplotlib (bool): Flag to display the force plot using matplotlib (outside of jupyter).
            Defaults to False.

    Returns:
        list[dict[shap.AdditiveForceVisualizer]]: The same as force_plot(), but with an
            additional "plot" key in each per-class dictionary holding the generated plot.
    """

    def gen_force_plot(shap_values, feature_data, row_index, expected_value, matplotlib):
        """Helper function to generate a single force plot for one row and one class."""
        # Internal invariant: there must be as many features as shap values.
        assert feature_data.shape[1] == len(shap_values)

        # Show the feature values of the row actually being explained. Sampling
        # row 0 unconditionally would label every plot with the first row's values.
        # TODO: Update this to make the force plot display multi-row array visualizer.
        feature_values = feature_data.iloc[row_index]
        return shap.force_plot(
            expected_value,
            np.array(shap_values),
            feature_values,
            matplotlib=matplotlib,
        )

    # The JS visualisation only needs to be initialised inside a notebook.
    if jupyter_check():
        initjs()

    shap_plots = force_plot(pipeline, rows_to_explain, training_data, y)

    # force_plot() appends one dictionary per explained row; this pairing assumes
    # they come back in the same order as rows_to_explain (NOTE(review): relies on
    # explain_predictions preserving the requested order - confirm).
    for row_index, row in zip(rows_to_explain, shap_plots):
        for cls_dict in row.values():
            cls_dict["plot"] = gen_force_plot(
                cls_dict["shap_values"],
                training_data[cls_dict["feature_names"]],
                row_index,
                cls_dict["expected_value"],
                matplotlib=matplotlib,
            )
    return shap_plots
def force_plot(pipeline, rows_to_explain, training_data, y):
    """Function to generate the data required to build a force plot.

    Args:
        pipeline (PipelineBase): The pipeline to generate the force plot for.
        rows_to_explain (list[int]): A list of the indices of the training_data to explain.
        training_data (pandas.DataFrame): The data used to train the pipeline.
        y (pandas.Series): The target data.

    Returns:
        list[dict]: One dictionary per explained row. Each dictionary maps a class name
            (or "regression" when the explanation carries no class name) to that
            class's force plot data.

            For single row binary force plots:
                [{'malignant': {'expected_value': 0.37,
                                'feature_names': ['worst concave points', 'worst perimeter', 'worst radius'],
                                'shap_values': [0.09, 0.09, 0.08]}}]

            For two row binary force plots:
                [{'malignant': {'expected_value': 0.37,
                                'feature_names': ['worst concave points', 'worst perimeter', 'worst radius'],
                                'shap_values': [0.09, 0.09, 0.08]}},
                 {'malignant': {'expected_value': 0.29,
                                'feature_names': ['worst concave points', 'worst perimeter', 'worst radius'],
                                'shap_values': [0.05, 0.03, 0.02]}}]

    Raises:
        TypeError: If rows_to_explain is not a list.
        TypeError: If all values in rows_to_explain aren't integers.
    """
    # Validate the requested rows up front, before any explanation work is done.
    if not isinstance(rows_to_explain, list):
        raise TypeError(
            "rows_to_explain should be provided as a list of row index integers!"
        )
    if any(not isinstance(index, int) for index in rows_to_explain):
        raise TypeError("rows_to_explain should only contain integers!")

    # Request every feature (top_k == number of columns) together with the raw
    # SHAP values, in dictionary form.
    report = explain_predictions(
        pipeline,
        training_data,
        y,
        rows_to_explain,
        top_k_features=len(training_data.columns),
        include_shap_values=True,
        output_format="dict",
    )

    force_plot_data = []
    for row_report in report["explanations"]:
        # A missing class name is reported under the "regression" key.
        per_class = {
            (
                "regression"
                if class_report["class_name"] is None
                else class_report["class_name"]
            ): {
                "expected_value": class_report["expected_value"],
                "feature_names": class_report["feature_names"],
                "shap_values": class_report["quantitative_explanation"],
            }
            for class_report in row_report["explanations"]
        }
        force_plot_data.append(per_class)
    return force_plot_data