class_imbalance_data_check
Data check that checks whether any target labels are imbalanced, or whether the number of samples for any target value is below 2 times the number of CV folds.
Use for classification problems.
Module Contents
Classes Summary
ClassImbalanceDataCheck | Check whether any target labels are imbalanced, or whether the number of samples for any target value is below 2 times the number of CV folds. Use for classification problems.
Contents
class evalml.data_checks.class_imbalance_data_check.ClassImbalanceDataCheck(threshold=0.1, min_samples=100, num_cv_folds=3)
Check whether any target labels are imbalanced, or whether the number of samples for any target value is below 2 times the number of CV folds. Use for classification problems.
- Parameters
threshold (float) – The minimum class-imbalance ratio allowed before a warning is raised. The ratio for a class is the number of samples in that class divided by the sum of the samples in that class and in the majority class. For example, a multiclass case with [900, 900, 100] samples for classes 0, 1, and 2, respectively, yields a ratio of 0.10 for class 2 (100 / (900 + 100)). Defaults to 0.10.
min_samples (int) – The minimum number of samples per accepted class. If the minority class is both below the threshold and min_samples, then we consider this severely imbalanced. Must be greater than 0. Defaults to 100.
num_cv_folds (int) – The number of cross-validation folds. Must be non-negative. Set to 0 to skip the check against 2 times the number of CV folds. Defaults to 3.
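The three parameters above can be illustrated with a small plain-Python sketch. This is not evalml's implementation: the helper names `imbalance_ratios` and `flag_labels` are hypothetical, and the strict less-than comparisons are an assumption based on the parameter descriptions.

```python
from collections import Counter


def imbalance_ratios(labels):
    """Return {label: count / (count + majority_count)} for every label."""
    counts = Counter(labels)
    majority = max(counts.values())
    return {label: n / (n + majority) for label, n in counts.items()}


def flag_labels(labels, threshold=0.1, min_samples=100, num_cv_folds=3):
    """Hypothetical helper: group labels by the condition they trip."""
    counts = Counter(labels)
    ratios = imbalance_ratios(labels)
    # Assumption: a label is flagged when its ratio is strictly below threshold.
    below_threshold = sorted(l for l, r in ratios.items() if r < threshold)
    # Severe imbalance: below the threshold AND fewer than min_samples samples.
    severe = sorted(l for l in below_threshold if counts[l] < min_samples)
    # num_cv_folds == 0 disables the fold check entirely.
    below_folds = []
    if num_cv_folds > 0:
        below_folds = sorted(l for l, n in counts.items() if n < 2 * num_cv_folds)
    return {
        "below_threshold": below_threshold,
        "severe": severe,
        "below_folds": below_folds,
    }


# The multiclass case from the threshold description: 100 / (900 + 100).
ratios = imbalance_ratios([0] * 900 + [1] * 900 + [2] * 100)
print(ratios[2])

# One sample of class 0 against ten of class 1: 1 / (1 + 10).
print(flag_labels([0] + [1] * 10))
```

Note that the majority class always has a ratio of 0.5 under this definition, so it can only be flagged if the threshold is set above 0.5.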
Methods
name | Return a name describing the data check.
validate | Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.
name(cls)
Return a name describing the data check.
validate(self, X, y)
Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.
Ignores NaN values in target labels if they appear.
- Parameters
X (pd.DataFrame, np.ndarray) – Features. Ignored.
y (pd.Series, np.ndarray) – Target labels to check for imbalanced data.
- Returns
Dictionary with DataCheckWarnings if the imbalance ratio of any class is below the threshold, and DataCheckErrors if the number of samples for any target value is below 2 * num_cv_folds.
- Return type
dict
Example
>>> import pandas as pd
>>> from evalml.data_checks.class_imbalance_data_check import ClassImbalanceDataCheck
>>> X = pd.DataFrame()
>>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> target_check = ClassImbalanceDataCheck(threshold=0.10)
>>> assert target_check.validate(X, y) == {
...     "errors": [{"message": "The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]",
...                 "data_check_name": "ClassImbalanceDataCheck",
...                 "level": "error",
...                 "code": "CLASS_IMBALANCE_BELOW_FOLDS",
...                 "details": {"target_values": [0], "rows": None, "columns": None}}],
...     "warnings": [{"message": "The following labels fall below 10% of the target: [0]",
...                   "data_check_name": "ClassImbalanceDataCheck",
...                   "level": "warning",
...                   "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
...                   "details": {"target_values": [0], "rows": None, "columns": None}},
...                  {"message": "The following labels in the target have severe class imbalance because they fall under 10% of the target and have less than 100 samples: [0]",
...                   "data_check_name": "ClassImbalanceDataCheck",
...                   "level": "warning",
...                   "code": "CLASS_IMBALANCE_SEVERE",
...                   "details": {"target_values": [0], "rows": None, "columns": None}}],
...     "actions": []}
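A dictionary of this shape can be inspected without evalml itself. The sketch below assumes only the structure shown in the example above (top-level "errors", "warnings", and "actions" lists whose entries carry "code" and "details" keys). The helper name `summarize_result` is hypothetical, and the message strings are abbreviated for brevity.

```python
def summarize_result(result):
    """Hypothetical helper: map each reported code to its flagged target values."""
    summary = {}
    for level in ("errors", "warnings"):
        for entry in result.get(level, []):
            summary[entry["code"]] = entry["details"]["target_values"]
    return summary


# Hand-written stand-in for the output of validate(X, y); messages shortened.
result = {
    "errors": [
        {"message": "...", "data_check_name": "ClassImbalanceDataCheck",
         "level": "error", "code": "CLASS_IMBALANCE_BELOW_FOLDS",
         "details": {"target_values": [0], "rows": None, "columns": None}},
    ],
    "warnings": [
        {"message": "...", "data_check_name": "ClassImbalanceDataCheck",
         "level": "warning", "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
         "details": {"target_values": [0], "rows": None, "columns": None}},
        {"message": "...", "data_check_name": "ClassImbalanceDataCheck",
         "level": "warning", "code": "CLASS_IMBALANCE_SEVERE",
         "details": {"target_values": [0], "rows": None, "columns": None}},
    ],
    "actions": [],
}

for code, targets in summarize_result(result).items():
    print(code, targets)
```

Keying on the "code" field rather than the message text keeps such checks stable if the wording of the messages changes.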