import pandas as pd

from .data_check import DataCheck
from .data_check_message import DataCheckWarning


class LabelLeakageDataCheck(DataCheck):
    """Check if any of the features are highly correlated with the target."""

    def __init__(self, pct_corr_threshold=0.95):
        """Check if any of the features are highly correlated with the target.

        Currently only supports binary and numeric targets and features.

        Arguments:
            pct_corr_threshold (float): The correlation threshold to be considered leakage. Defaults to 0.95.

        Raises:
            ValueError: If pct_corr_threshold is less than 0 or greater than 1.
        """
        if pct_corr_threshold < 0 or pct_corr_threshold > 1:
            raise ValueError("pct_corr_threshold must be a float between 0 and 1, inclusive.")
        self.pct_corr_threshold = pct_corr_threshold

    def validate(self, X, y):
        """Check if any of the features are highly correlated with the target.

        Currently only supports binary and numeric targets and features.
        Non-numeric feature columns are ignored. A column whose correlation
        with the target is NaN (for example, a constant column) never
        triggers a warning, because NaN compares false against the threshold.

        Arguments:
            X (pd.DataFrame): The input features to check
            y (pd.Series): The labels

        Returns:
            list (DataCheckWarning): list with a DataCheckWarning if there is label leakage detected.

        Example:
            >>> X = pd.DataFrame({
            ...    'leak': [10, 42, 31, 51, 61],
            ...    'x': [42, 54, 12, 64, 12],
            ...    'y': [12, 5, 13, 74, 24],
            ... })
            >>> y = pd.Series([10, 42, 31, 51, 40])
            >>> label_leakage_check = LabelLeakageDataCheck(pct_corr_threshold=0.8)
            >>> assert label_leakage_check.validate(X, y) == [DataCheckWarning("Column 'leak' is 80.0% or more correlated with the target", "LabelLeakageDataCheck")]
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if not isinstance(y, pd.Series):
            y = pd.Series(y)

        # only select numeric
        numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64', 'bool']
        X = X.select_dtypes(include=numerics)
        if len(X.columns) == 0:
            return []

        # DataFrame.items() replaces iteritems(), which was removed in
        # pandas 2.0. Each correlation is computed exactly once; a dict keyed
        # by column label keeps one entry per unique label.
        highly_corr_cols = {}
        for col_name, col in X.items():
            corr = abs(y.corr(col))
            if corr >= self.pct_corr_threshold:
                highly_corr_cols[col_name] = corr

        warning_msg = "Column '{}' is {}% or more correlated with the target"
        return [DataCheckWarning(warning_msg.format(col_name, self.pct_corr_threshold * 100), self.name)
                for col_name in highly_corr_cols]