Source code for evalml.data_checks.natural_language_nan_data_check

from evalml.data_checks import DataCheck, DataCheckError, DataCheckMessageCode
from evalml.utils.woodwork_utils import infer_feature_types

error_contains_nan = "Input natural language column(s) ({}) contains NaN values. Please impute NaN values or drop these rows or columns."


[docs]class NaturalLanguageNaNDataCheck(DataCheck): """Checks if natural language columns contain NaN values."""
[docs] def __init__(self): """Checks each column in the input for natural language features and will issue an error if NaN values are present. """
[docs] def validate(self, X, y=None): """Checks if any natural language columns contain NaN values. Arguments: X (pd.DataFrame, np.ndarray): Features. y (pd.Series, np.ndarray): Ignored. Defaults to None. Returns: dict: dict with a DataCheckError if NaN values are present in natural language columns. Example: >>> import pandas as pd >>> import woodwork as ww >>> import numpy as np >>> data = pd.DataFrame() >>> data['A'] = [None, "string_that_is_long_enough_for_natural_language"] >>> data['B'] = ['string_that_is_long_enough_for_natural_language', 'string_that_is_long_enough_for_natural_language'] >>> data['C'] = np.random.randint(0, 3, size=len(data)) >>> data.ww.init(logical_types={'A': 'NaturalLanguage', 'B': 'NaturalLanguage'}) >>> nl_nan_check = NaturalLanguageNaNDataCheck() >>> assert nl_nan_check.validate(data) == { ... "warnings": [], ... "actions": [], ... "errors": [DataCheckError(message='Input natural language column(s) (A) contains NaN values. Please impute NaN values or drop these rows or columns.', ... data_check_name=NaturalLanguageNaNDataCheck.name, ... message_code=DataCheckMessageCode.NATURAL_LANGUAGE_HAS_NAN, ... details={"columns": 'A'}).to_dict()] ... } """ results = { "warnings": [], "errors": [], "actions": [] } X = infer_feature_types(X) X = X.ww.select('natural_language') X_describe = X.ww.describe_dict() nan_columns = [str(col) for col in X_describe if X_describe[col]['nan_count'] > 0] if len(nan_columns) > 0: cols_str = ', '.join(nan_columns) results["errors"].append(DataCheckError(message=error_contains_nan.format(cols_str), data_check_name=self.name, message_code=DataCheckMessageCode.NATURAL_LANGUAGE_HAS_NAN, details={"columns": cols_str}).to_dict()) return results