"""Data check that checks if the target data contains certain distributions that may need to be transformed prior to training to improve model performance."""
import numpy as np
import woodwork as ww
from scipy.stats import jarque_bera, shapiro
from evalml.data_checks import (
DataCheck,
DataCheckAction,
DataCheckActionCode,
DataCheckError,
DataCheckMessageCode,
DataCheckWarning,
)
from evalml.utils import infer_feature_types
class TargetDistributionDataCheck(DataCheck):
    """Check if the target data contains certain distributions that may need to be transformed prior to training to improve model performance.

    Uses the Shapiro-Wilk test when the target has 5000 or fewer non-null samples, otherwise uses Jarque-Bera.
    Targets with fewer than 3 non-null samples are not tested for a distribution.
    """

    # Largest sample count for which the Shapiro-Wilk test is used; larger targets use Jarque-Bera.
    _SHAPIRO_MAX_SAMPLES = 5000
    # Both normality tests need at least this many observations (scipy's shapiro raises below it).
    _MIN_SAMPLES = 3
    # Significance level: a p-value at or above this means normality is not rejected.
    _NORMALITY_ALPHA = 0.05

    def validate(self, X, y):
        """Check if the target data has a certain distribution.

        Args:
            X (pd.DataFrame, np.ndarray): Features. Ignored.
            y (pd.Series, np.ndarray): Target data to check for underlying distributions.

        Returns:
            dict: Dictionary with "warnings", "errors" and "actions" lists. A target that looks lognormal
            adds a DataCheckWarning and a TRANSFORM_TARGET DataCheckAction. A missing target or one whose
            Woodwork logical type is not integer or double adds a DataCheckError.

        Examples:
            >>> import pandas as pd

            Targets that exhibit a lognormal distribution produce a warning suggesting the target be transformed.

            >>> y = [0.946, 0.972, 1.154, 0.954, 0.969, 1.222, 1.038, 0.999, 0.973, 0.897]
            >>> target_check = TargetDistributionDataCheck()
            >>> assert target_check.validate(None, y) == {
            ...     "errors": [],
            ...     "warnings": [{"message": "Target may have a lognormal distribution.",
            ...                   "data_check_name": "TargetDistributionDataCheck",
            ...                   "level": "warning",
            ...                   "code": "TARGET_LOGNORMAL_DISTRIBUTION",
            ...                   "details": {"shapiro-statistic/pvalue": '0.8/0.045', "columns": None, "rows": None}}],
            ...     "actions": [{'code': 'TRANSFORM_TARGET',
            ...                  'metadata': {'transformation_strategy': 'lognormal',
            ...                               'is_target': True,
            ...                               "columns": None,
            ...                               "rows": None}}]}

            A target that is not flagged as lognormal returns empty results.

            >>> y = pd.Series([1, 1, 1, 2, 2, 3, 4, 4, 5, 5, 5])
            >>> assert target_check.validate(None, y) == {'warnings': [], 'errors': [], 'actions': []}

            Only integer and double targets are supported; other logical types such as datetime produce an error.

            >>> y = pd.Series(pd.date_range('1/1/21', periods=10))
            >>> assert target_check.validate(None, y) == {
            ...     'warnings': [],
            ...     'errors': [{'message': 'Target is unsupported datetime type. Valid Woodwork logical types include: integer, double',
            ...                 'data_check_name': 'TargetDistributionDataCheck',
            ...                 'level': 'error',
            ...                 'details': {'columns': None, 'rows': None, 'unsupported_type': 'datetime'},
            ...                 'code': 'TARGET_UNSUPPORTED_TYPE'}],
            ...     'actions': []}
        """
        results = {"warnings": [], "errors": [], "actions": []}
        if y is None:
            results["errors"].append(
                DataCheckError(
                    message="Target is None",
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.TARGET_IS_NONE,
                    details={},
                ).to_dict()
            )
            return results

        y = infer_feature_types(y)
        allowed_types = [
            ww.logical_types.Integer.type_string,
            ww.logical_types.Double.type_string,
        ]
        target_type = y.ww.logical_type.type_string
        if target_type not in allowed_types:
            results["errors"].append(
                DataCheckError(
                    message="Target is unsupported {} type. Valid Woodwork logical types include: {}".format(
                        target_type,
                        ", ".join(allowed_types),
                    ),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
                    details={"unsupported_type": target_type},
                ).to_dict()
            )
            return results

        detection = self._detect_lognormal(y)
        if detection is None:
            return results

        test_name, norm_test_og = detection
        details = {
            f"{test_name}-statistic/pvalue": f"{round(norm_test_og.statistic, 1)}/{round(norm_test_og.pvalue, 3)}"
        }
        results["warnings"].append(
            DataCheckWarning(
                message="Target may have a lognormal distribution.",
                data_check_name=self.name,
                message_code=DataCheckMessageCode.TARGET_LOGNORMAL_DISTRIBUTION,
                details=details,
            ).to_dict()
        )
        results["actions"].append(
            DataCheckAction(
                DataCheckActionCode.TRANSFORM_TARGET,
                metadata={
                    "is_target": True,
                    "transformation_strategy": "lognormal",
                },
            ).to_dict()
        )
        return results

    @classmethod
    def _detect_lognormal(cls, y):
        """Decide whether a numeric target looks more normal after a log transform.

        Args:
            y (pd.Series): Numeric target.

        Returns:
            tuple or None: ``(test_name, result)`` when a lognormal distribution is detected. ``test_name``
            is "shapiro" or "jarque_bera", and ``result`` is the normality test result on the
            shifted/outlier-trimmed (but not log-transformed) target. Returns None when the target is
            already normal, too small to test, or not more normal on the log scale.
        """
        # NaNs would make every test statistic NaN, so only observed values are tested.
        y = y.dropna()
        if len(y) < cls._MIN_SAMPLES:
            return None

        use_shapiro = len(y) <= cls._SHAPIRO_MAX_SAMPLES
        normalization_test = shapiro if use_shapiro else jarque_bera
        test_name = "shapiro" if use_shapiro else "jarque_bera"

        # Normality is not rejected, so no transformation is needed.
        if normalization_test(y).pvalue >= cls._NORMALITY_ALPHA:
            return None

        y_new = y.round(6)
        # np.log needs strictly positive input, so shift the values so that the minimum becomes 1.
        # The check runs on the *rounded* values because rounding can turn tiny positives into 0.
        if (y_new <= 0).any():
            y_new = y_new + abs(y_new.min()) + 1
        # Drop values greater than 3 standard deviations above the mean.
        y_new = y_new[y_new < (y_new.mean() + 3 * round(y.std(), 3))]
        # A constant target has std 0, and the strict "<" above then drops every value.
        if len(y_new) < cls._MIN_SAMPLES:
            return None

        norm_test_og = normalization_test(y_new)
        norm_test_log = normalization_test(np.log(y_new))
        # If the log-transformed target's p-value is at least that of the (outlier-trimmed) original,
        # the log-transformed target is at least as close to a normal distribution.
        if norm_test_log.pvalue >= norm_test_og.pvalue:
            return test_name, norm_test_og
        return None