# Source code for evalml.data_checks.target_leakage_data_check

"""Data check that checks if any of the features are highly correlated with the target by using mutual information or Pearson correlation."""

from woodwork.config import CONFIG_DEFAULTS

from evalml.data_checks import (
    DataCheck,
    DataCheckActionCode,
    DataCheckActionOption,
    DataCheckMessageCode,
    DataCheckWarning,
)
from evalml.utils.woodwork_utils import infer_feature_types


class TargetLeakageDataCheck(DataCheck):
    """Check if any of the features are highly correlated with the target by using mutual information, Pearson correlation, and other correlation metrics.

    If method='mutual_info', this data check uses mutual information and supports all target and feature types.
    Other correlation metrics only support binary with numeric and boolean dtypes. Other correlation metrics
    return a value in [-1, 1], while mutual information returns a value in [0, 1]. A feature is flagged when the
    absolute value of its dependence with the target is greater than or equal to `pct_corr_threshold`.
    Correlation metrics available can be found in Woodwork's `dependence_dict method <https://woodwork.alteryx.com/en/stable/generated/woodwork.table_accessor.WoodworkTableAccessor.dependence_dict.html#woodwork.table_accessor.WoodworkTableAccessor.dependence_dict>`_.

    Args:
        pct_corr_threshold (float): The correlation threshold to be considered leakage. Must be in [0, 1]. Defaults to 0.95.
        method (string): The method to determine correlation. Use 'all' or 'max' for the maximum correlation, or for specific correlation metrics, use their name (ie 'mutual_info' for mutual information, 'pearson' for Pearson correlation, etc).
            Possible methods can be found in Woodwork's `config <https://woodwork.alteryx.com/en/stable/guides/setting_config_options.html?highlight=config#Viewing-Config-Settings>`_, under `correlation_metrics`.
            Defaults to 'all'.

    Raises:
        ValueError: If `pct_corr_threshold` is outside [0, 1] (including NaN), or if `method` is not one of
            Woodwork's default `correlation_metrics`.
    """

    def __init__(self, pct_corr_threshold=0.95, method="all"):
        # Written as a negated range check so that NaN (for which every
        # comparison is False) is rejected instead of silently accepted; a NaN
        # threshold would otherwise make the check never flag any column.
        if not 0 <= pct_corr_threshold <= 1:
            raise ValueError(
                "pct_corr_threshold must be a float between 0 and 1, inclusive.",
            )
        methods = CONFIG_DEFAULTS["correlation_metrics"]
        if method not in methods:
            raise ValueError(
                f"Method '{method}' not in available correlation methods. Available methods include {methods}",
            )
        self.pct_corr_threshold = pct_corr_threshold
        self.method = method
        # When method is 'all', the per-column value is read from the 'max'
        # entry of each dependence_dict result; otherwise from the entry named
        # after the requested metric.
        self._method_to_check = "max" if method == "all" else method

    def _calculate_dependence(self, X, y):
        """Return the names of the columns in X whose dependence with y meets the threshold.

        Args:
            X (pd.DataFrame): Features; expected to have Woodwork typing initialized
                (validate ensures this via infer_feature_types).
            y (pd.Series): Target data.

        Returns:
            list: Column names whose absolute dependence value is >= pct_corr_threshold,
                ordered by their position in X. Empty if Woodwork raises a KeyError because
                the target is incompatible with the selected metric.
        """
        X2 = X.ww.copy()
        # Pick a name for the target column that does not collide with any feature.
        target_str = "target_y"
        while target_str in X2.columns:
            target_str += "_y"
        X2.ww[target_str] = y

        try:
            dep_corr = X2.ww.dependence_dict(
                measures=self.method,
                target_col=target_str,
            )
        except KeyError:
            # KeyError is raised when the target does not appear due to
            # incompatibility with the metric; nothing can be flagged.
            return []

        # Compute each column's position once, instead of rebuilding the column
        # list and scanning it inside the sort key for every flagged column.
        column_position = {col: pos for pos, col in enumerate(X2.columns)}
        return sorted(
            (
                corr_info["column_1"]
                for corr_info in dep_corr
                if abs(corr_info[self._method_to_check]) >= self.pct_corr_threshold
            ),
            key=column_position.__getitem__,
        )

    def validate(self, X, y):
        """Check if any of the features are highly correlated with the target by using mutual information, Pearson correlation, and/or Spearman correlation.

        If `method='mutual_info'` or `method='max'`, supports all target and feature types.
        Other correlation metrics only support binary with numeric and boolean dtypes.
        Other correlation metrics return a value in [-1, 1], while mutual information returns a value in [0, 1].

        Args:
            X (pd.DataFrame, np.ndarray): The input features to check.
            y (pd.Series, np.ndarray): The target data.

        Returns:
            list: A list containing one DataCheckWarning dict if target leakage is detected,
                otherwise an empty list.

        Examples:
            >>> import pandas as pd

            Any columns that are strongly correlated with the target will raise a warning. This could be indicative of data leakage.

            >>> X = pd.DataFrame({
            ...    "leak": [10, 42, 31, 51, 61] * 15,
            ...    "x": [42, 54, 12, 64, 12] * 15,
            ...    "y": [13, 5, 13, 74, 24] * 15,
            ... })
            >>> y = pd.Series([10, 42, 31, 51, 40] * 15)
            ...
            >>> target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.95)
            >>> assert target_leakage_check.validate(X, y) == [
            ...     {
            ...         "message": "Column 'leak' is 95.0% or more correlated with the target",
            ...         "data_check_name": "TargetLeakageDataCheck",
            ...         "level": "warning",
            ...         "code": "TARGET_LEAKAGE",
            ...         "details": {"columns": ["leak"], "rows": None},
            ...         "action_options": [
            ...             {
            ...                 "code": "DROP_COL",
            ...                 "data_check_name": "TargetLeakageDataCheck",
            ...                 "parameters": {},
            ...                 "metadata": {"columns": ["leak"], "rows": None}
            ...             }
            ...         ]
            ...     }
            ... ]

            The default method can be changed to pearson from mutual_info.

            >>> X["x"] = y / 2
            >>> target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8, method="pearson")
            >>> assert target_leakage_check.validate(X, y) == [
            ...     {
            ...         "message": "Columns 'leak', 'x' are 80.0% or more correlated with the target",
            ...         "data_check_name": "TargetLeakageDataCheck",
            ...         "level": "warning",
            ...         "details": {"columns": ["leak", "x"], "rows": None},
            ...         "code": "TARGET_LEAKAGE",
            ...         "action_options": [
            ...             {
            ...                 "code": "DROP_COL",
            ...                 "data_check_name": "TargetLeakageDataCheck",
            ...                 "parameters": {},
            ...                 "metadata": {"columns": ["leak", "x"], "rows": None}
            ...             }
            ...         ]
            ...     }
            ... ]
        """
        messages = []

        X = infer_feature_types(X)
        y = infer_feature_types(y)

        highly_corr_cols = self._calculate_dependence(X, y)
        if not highly_corr_cols:
            return messages

        warning_msg_singular = "Column {} is {}% or more correlated with the target"
        warning_msg_plural = "Columns {} are {}% or more correlated with the target"
        # Joining a single name yields just that name, so one quoting path
        # serves both the singular and plural messages.
        quoted_cols = ", ".join(f"'{col!s}'" for col in highly_corr_cols)
        template = (
            warning_msg_singular if len(highly_corr_cols) == 1 else warning_msg_plural
        )
        warning_msg = template.format(quoted_cols, self.pct_corr_threshold * 100)

        messages.append(
            DataCheckWarning(
                message=warning_msg,
                data_check_name=self.name,
                message_code=DataCheckMessageCode.TARGET_LEAKAGE,
                details={"columns": highly_corr_cols},
                action_options=[
                    DataCheckActionOption(
                        DataCheckActionCode.DROP_COL,
                        data_check_name=self.name,
                        metadata={"columns": highly_corr_cols},
                    ),
                ],
            ).to_dict(),
        )
        return messages