class_imbalance_data_check

Data check that checks whether any of the target labels are imbalanced, or whether the number of instances of any target value is below 2 times the number of CV folds.

Use for classification problems.

Module Contents

Classes Summary

ClassImbalanceDataCheck

Check whether any of the target labels are imbalanced, or whether the number of instances of any target value is below 2 times the number of CV folds. Use for classification problems.

Contents

class evalml.data_checks.class_imbalance_data_check.ClassImbalanceDataCheck(threshold=0.1, min_samples=100, num_cv_folds=3)

Check whether any of the target labels are imbalanced, or whether the number of instances of any target value is below 2 times the number of CV folds. Use for classification problems.

Parameters
  • threshold (float) – The minimum threshold allowed for class imbalance before a warning is raised. Each class's imbalance ratio is calculated by dividing the number of samples in that class by the sum of the samples in that class and in the majority class. For example, in a multiclass case with [900, 900, 100] samples for classes 0, 1, and 2, respectively, class 2 would have a ratio of 0.10 (100 / (900 + 100)). Defaults to 0.10.

  • min_samples (int) – The minimum number of samples a class needs to be considered acceptable. If a minority class falls below both the threshold and min_samples, it is considered severely imbalanced. Must be greater than 0. Defaults to 100.

  • num_cv_folds (int) – The number of cross-validation folds. Must be non-negative. Set to 0 to skip the check that each target value has at least 2 * num_cv_folds instances. Defaults to 3.
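To make the threshold and min_samples rules concrete, here is a minimal sketch in plain pandas. It assumes a class is flagged only when its ratio is strictly below threshold, and that NaN labels are dropped before counting. The helpers `imbalance_ratios` and `flag_imbalanced` are hypothetical and not part of evalml's API; this illustrates the rule described above rather than reproducing the library's implementation.

```python
import pandas as pd


def imbalance_ratios(counts: pd.Series) -> pd.Series:
    # Hypothetical helper (not evalml API): each class count divided by
    # (that class count + the majority class count).
    return counts / (counts + counts.max())


def flag_imbalanced(y: pd.Series, threshold: float = 0.10, min_samples: int = 100):
    # Hypothetical helper (not evalml API). Assumes labels are mutually sortable.
    counts = y.value_counts()  # value_counts() drops NaN labels by default
    ratios = imbalance_ratios(counts)
    # Assumption: a class is flagged only when its ratio is strictly below threshold.
    below_threshold = sorted(ratios[ratios < threshold].index.tolist())
    # "Severe" = below the ratio threshold AND fewer than min_samples instances.
    severe = [label for label in below_threshold if counts.loc[label] < min_samples]
    return below_threshold, severe


y = pd.Series([0] * 900 + [1] * 900 + [2] * 100)
print(imbalance_ratios(y.value_counts()).sort_index().tolist())  # [0.5, 0.5, 0.1]
print(flag_imbalanced(y))  # ([], []): class 2 sits exactly at 0.10, not below it
print(flag_imbalanced(pd.Series([0] + [1] * 10)))  # ([0], [0]): 1 / 11 is about 0.09
```

Note that the majority class always has a ratio of 0.5, so with any threshold at or below 0.5 only minority classes can be flagged.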

Methods

name

Return a name describing the data check.

validate

Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.

name(cls)

Return a name describing the data check.

validate(self, X, y)

Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.

Ignores NaN values in target labels if they appear.

Parameters
  • X (pd.DataFrame, np.ndarray) – Features. Ignored.

  • y (pd.Series, np.ndarray) – Target labels to check for imbalanced data.

Returns

Dictionary with DataCheckWarnings if any class's imbalance ratio is below the threshold (plus a severe-imbalance warning if that class also has fewer than min_samples samples), and DataCheckErrors if the number of instances of any target value is below 2 * num_cv_folds.

Return type

dict
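The fold-count rule from the Returns description can be sketched the same way. This is a minimal sketch, assuming NaN labels are dropped before counting (as the method description states) and that num_cv_folds = 0 disables the check because no count can fall below 0. `labels_below_fold_minimum` is a hypothetical helper, not part of evalml.

```python
import numpy as np
import pandas as pd


def labels_below_fold_minimum(y, num_cv_folds: int = 3) -> list:
    # Hypothetical helper (not evalml API): labels with fewer than
    # 2 * num_cv_folds instances. Accepts a list, np.ndarray, or pd.Series.
    if num_cv_folds < 0:
        raise ValueError("num_cv_folds must be non-negative")
    counts = pd.Series(y).value_counts()  # NaN labels are dropped here
    min_instances = 2 * num_cv_folds  # 0 folds -> 0, so nothing is flagged
    return sorted(counts[counts < min_instances].index.tolist())


y = np.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
print(labels_below_fold_minimum(y))  # [0]: 1 instance < 2 * 3 = 6
print(labels_below_fold_minimum(y, num_cv_folds=0))  # []
```

The factor of 2 leaves room for each class to appear in both the training and validation portion of every fold; the sketch simply applies that minimum as a count comparison.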

Example

>>> import pandas as pd
>>> from evalml.data_checks import ClassImbalanceDataCheck
>>> X = pd.DataFrame()
>>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> target_check = ClassImbalanceDataCheck(threshold=0.10)
>>> assert target_check.validate(X, y) == {
...     "errors": [{"message": "The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]",
...                 "data_check_name": "ClassImbalanceDataCheck",
...                 "level": "error",
...                 "code": "CLASS_IMBALANCE_BELOW_FOLDS",
...                 "details": {"target_values": [0]}}],
...     "warnings": [{"message": "The following labels fall below 10% of the target: [0]",
...                   "data_check_name": "ClassImbalanceDataCheck",
...                   "level": "warning",
...                   "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
...                   "details": {"target_values": [0]}},
...                  {"message": "The following labels in the target have severe class imbalance because they fall under 10% of the target and have less than 100 samples: [0]",
...                   "data_check_name": "ClassImbalanceDataCheck",
...                   "level": "warning",
...                   "code": "CLASS_IMBALANCE_SEVERE",
...                   "details": {"target_values": [0]}}],
...     "actions": []}
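Because validate returns a plain dictionary here, downstream code can inspect it directly. The sketch below assumes the errors/warnings/details layout shown in the example above and collects the target values reported under each message code. `flagged_labels` is a hypothetical helper, and the input dictionary is a trimmed copy of the example's expected output rather than a live evalml call.

```python
def flagged_labels(result: dict) -> dict:
    # Hypothetical helper (not evalml API): map each message code to the
    # target values it reports, across both errors and warnings.
    flagged = {}
    for key in ("errors", "warnings"):
        for message in result.get(key, []):
            flagged[message["code"]] = message["details"]["target_values"]
    return flagged


# Trimmed copy of the expected output from the example (messages omitted).
result = {
    "errors": [{"level": "error", "code": "CLASS_IMBALANCE_BELOW_FOLDS",
                "details": {"target_values": [0]}}],
    "warnings": [{"level": "warning", "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
                  "details": {"target_values": [0]}},
                 {"level": "warning", "code": "CLASS_IMBALANCE_SEVERE",
                  "details": {"target_values": [0]}}],
    "actions": [],
}
print(flagged_labels(result))
# {'CLASS_IMBALANCE_BELOW_FOLDS': [0], 'CLASS_IMBALANCE_BELOW_THRESHOLD': [0], 'CLASS_IMBALANCE_SEVERE': [0]}
```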