Source code for evalml.data_checks.target_distribution_data_check
"""Data check that checks if the target data contains certain distributions that may need to be transformed prior training to improve model performance."""importnumpyasnpimportwoodworkaswwfromscipy.statsimportjarque_bera,shapirofromevalml.data_checksimport(DataCheck,DataCheckActionCode,DataCheckActionOption,DataCheckError,DataCheckMessageCode,DataCheckWarning,)fromevalml.utilsimportinfer_feature_types
class TargetDistributionDataCheck(DataCheck):
    """Check if the target data contains certain distributions that may need to be transformed prior training to improve model performance.

    Uses the Shapiro-Wilk test when the dataset is <=5000 samples, otherwise uses Jarque-Bera.
    """

    def validate(self, X, y):
        """Check if the target data has a certain distribution.

        Args:
            X (pd.DataFrame, np.ndarray): Features. Ignored.
            y (pd.Series, np.ndarray): Target data to check for underlying distributions.

        Returns:
            list[dict]: Messages in dictionary form. Contains a single DataCheckError if the
                target is None or has an unsupported Woodwork logical type, a single
                DataCheckWarning if the target may have a lognormal distribution, and is
                empty otherwise.

        Examples:
            >>> import pandas as pd

            Targets that exhibit a lognormal distribution will raise a warning for the user to transform the target.

            >>> y = [0.946, 0.972, 1.154, 0.954, 0.969, 1.222, 1.038, 0.999, 0.973, 0.897]
            >>> target_check = TargetDistributionDataCheck()
            >>> assert target_check.validate(None, y) == [
            ...     {
            ...         "message": "Target may have a lognormal distribution.",
            ...         "data_check_name": "TargetDistributionDataCheck",
            ...         "level": "warning",
            ...         "code": "TARGET_LOGNORMAL_DISTRIBUTION",
            ...         "details": {"normalization_method": "shapiro", "statistic": 0.8, "p-value": 0.045, "columns": None, "rows": None},
            ...         "action_options": [
            ...             {
            ...                 "code": "TRANSFORM_TARGET",
            ...                 "data_check_name": "TargetDistributionDataCheck",
            ...                 "parameters": {},
            ...                 "metadata": {
            ...                     "transformation_strategy": "lognormal",
            ...                     "is_target": True,
            ...                     "columns": None,
            ...                     "rows": None
            ...                 }
            ...             }
            ...         ]
            ...     }
            ... ]

            Targets that are not flagged as lognormal produce no messages.

            >>> y = pd.Series([1, 1, 1, 2, 2, 3, 4, 4, 5, 5, 5])
            >>> assert target_check.validate(None, y) == []

            Targets with an unsupported logical type produce an error.

            >>> y = pd.Series(pd.date_range("1/1/21", periods=10))
            >>> assert target_check.validate(None, y) == [
            ...     {
            ...         "message": "Target is unsupported datetime type. Valid Woodwork logical types include: integer, double, age, age_fractional",
            ...         "data_check_name": "TargetDistributionDataCheck",
            ...         "level": "error",
            ...         "details": {"columns": None, "rows": None, "unsupported_type": "datetime"},
            ...         "code": "TARGET_UNSUPPORTED_TYPE",
            ...         "action_options": []
            ...     }
            ... ]
        """
        messages = []

        # Nothing can be tested without a target: report it and stop.
        if y is None:
            messages.append(
                DataCheckError(
                    message="Target is None",
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.TARGET_IS_NONE,
                    details={},
                ).to_dict(),
            )
            return messages

        y = infer_feature_types(y)
        allowed_types = [
            ww.logical_types.Integer.type_string,
            ww.logical_types.Double.type_string,
            ww.logical_types.Age.type_string,
            ww.logical_types.AgeFractional.type_string,
        ]
        target_type = y.ww.logical_type.type_string

        # Only the logical types listed above are passed on to the distribution test.
        if target_type not in allowed_types:
            messages.append(
                DataCheckError(
                    message="Target is unsupported {} type. Valid Woodwork logical types include: {}".format(
                        target_type,
                        ", ".join(allowed_types),
                    ),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
                    details={"unsupported_type": target_type},
                ).to_dict(),
            )
            return messages

        (
            is_log_distribution,
            normalization_test_string,
            norm_test_og,
        ) = _detect_log_distribution_helper(y)

        if is_log_distribution:
            # Report the test result for the untransformed target, rounded for readability.
            details = {
                "normalization_method": normalization_test_string,
                "statistic": round(norm_test_og.statistic, 1),
                "p-value": round(norm_test_og.pvalue, 3),
            }
            messages.append(
                DataCheckWarning(
                    message="Target may have a lognormal distribution.",
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.TARGET_LOGNORMAL_DISTRIBUTION,
                    details=details,
                    action_options=[
                        DataCheckActionOption(
                            DataCheckActionCode.TRANSFORM_TARGET,
                            data_check_name=self.name,
                            metadata={
                                "is_target": True,
                                "transformation_strategy": "lognormal",
                            },
                        ),
                    ],
                ).to_dict(),
            )
        return messages
def _detect_log_distribution_helper(y):
    """Detect whether the target looks closer to normal after a log transform.

    Args:
        y (pd.Series): Numeric target values. The code relies on ``round(y, n)``,
            ``y.min()``, ``y.std()``, elementwise comparison and boolean-mask
            indexing, all of which a pandas Series provides.

    Returns:
        tuple: ``(is_log_distribution, normalization_test_string, norm_test_og)`` where
            ``is_log_distribution`` is a bool, ``normalization_test_string`` is
            ``"shapiro"`` or ``"jarque_bera"``, and ``norm_test_og`` is the normality
            test result for the target with large outliers dropped. ``norm_test_og``
            is ``None`` when the raw target's p-value is NaN or at least 0.05, in
            which case the comparison against the log transform is skipped.
    """
    # Shapiro-Wilk for up to 5000 samples, Jarque-Bera beyond that.
    use_shapiro = len(y) <= 5000
    normalization_test = shapiro if use_shapiro else jarque_bera
    normalization_test_string = "shapiro" if use_shapiro else "jarque_bera"

    # A p-value of at least 0.05 (or an undefined one) means normality is not
    # rejected, so there is no reason to consider a log transform.
    pvalue = normalization_test(y).pvalue
    if pvalue >= 0.05 or np.isnan(pvalue):
        return False, normalization_test_string, None

    y_new = round(y, 6)
    if (y <= 0).any():
        # The log is undefined for non-positive values: shift so the minimum is 1.
        # The shifted series is built from the unrounded input.
        y_new = y + abs(y.min()) + 1

    # Drop values greater than 3 standard deviations above the mean.
    y_new = y_new[y_new < (y_new.mean() + 3 * round(y.std(), 3))]

    norm_test_og = normalization_test(y_new)
    norm_test_log = normalization_test(np.log(y_new))

    # If the p-value of the log transformed target is greater than or equal to the
    # p-value of the original target with outliers dropped, then it would imply that
    # the log transformed target has more of a normal distribution.
    if round(norm_test_log.pvalue, 6) >= round(norm_test_og.pvalue, 6):
        return True, normalization_test_string, norm_test_og
    return False, normalization_test_string, norm_test_og