# Source code for evalml.data_checks.natural_language_nan_data_check

"""Data check that checks each column in the input for natural language features and will issue an error if NaN values are present."""
from evalml.data_checks import DataCheck, DataCheckError, DataCheckMessageCode
from evalml.utils.woodwork_utils import infer_feature_types

# Error-message template. The single "{}" placeholder is filled through
# str.format with the names of the natural language columns that contain NaN.
error_contains_nan = (
    "Input natural language column(s) ({}) contains NaN values. "
    "Please impute NaN values or drop these rows or columns."
)


class NaturalLanguageNaNDataCheck(DataCheck):
    """Checks each column in the input for natural language features and will issue an error if NaN values are present."""

    def validate(self, X, y=None):
        """Check if any natural language columns contain NaN values.

        Args:
            X (pd.DataFrame, np.ndarray): Features.
            y (pd.Series, np.ndarray): Ignored. Defaults to None.

        Returns:
            dict: dict with a DataCheckError if NaN values are present in natural language columns.
                The dict always has the keys "warnings", "errors" and "actions"; only "errors"
                is ever populated by this check, with at most one entry.

        Example:
            >>> import pandas as pd
            >>> import woodwork as ww
            >>> import numpy as np
            ...
            >>> data = pd.DataFrame()
            >>> data['A'] = [None, "string_that_is_long_enough_for_natural_language"]
            >>> data['B'] = ['string_that_is_long_enough_for_natural_language', 'string_that_is_long_enough_for_natural_language']
            >>> data['C'] = np.random.randint(0, 3, size=len(data))
            >>> data.ww.init(logical_types={'A': 'NaturalLanguage', 'B': 'NaturalLanguage'})
            ...
            >>> nl_nan_check = NaturalLanguageNaNDataCheck()
            >>> assert nl_nan_check.validate(data) == {
            ...     "warnings": [],
            ...     "actions": [],
            ...     "errors": [DataCheckError(message='Input natural language column(s) (A) contains NaN values. Please impute NaN values or drop these rows or columns.',
            ...                               data_check_name=NaturalLanguageNaNDataCheck.name,
            ...                               message_code=DataCheckMessageCode.NATURAL_LANGUAGE_HAS_NAN,
            ...                               details={"columns": ['A']}).to_dict()]
            ... }
        """
        results = {"warnings": [], "errors": [], "actions": []}

        # Run the input through infer_feature_types, then keep only the
        # columns selected by the "natural_language" woodwork selector.
        X = infer_feature_types(X)
        X = X.ww.select("natural_language")
        X_describe = X.ww.describe_dict()

        # Column names are converted to str so they can be joined into the
        # message; the same string names are reported in the error details.
        nan_columns = [
            str(col)
            for col, col_stats in X_describe.items()
            if col_stats["nan_count"] > 0
        ]

        # All offending columns are reported together in a single error.
        if nan_columns:
            results["errors"].append(
                DataCheckError(
                    message=error_contains_nan.format(", ".join(nan_columns)),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.NATURAL_LANGUAGE_HAS_NAN,
                    details={"columns": nan_columns},
                ).to_dict()
            )
        return results