class_imbalance_data_check
Module Contents
Classes Summary
ClassImbalanceDataCheck | Check if any of the target labels are imbalanced, or if the number of instances of any target value is below 2 times the number of CV folds. Use for classification problems.
Contents
class evalml.data_checks.class_imbalance_data_check.ClassImbalanceDataCheck(threshold=0.1, min_samples=100, num_cv_folds=3)
Check if any of the target labels are imbalanced, or if the number of instances of any target value is below 2 times the number of CV folds. Use for classification problems.
Parameters
threshold (float) – The minimum threshold allowed for class imbalance before a warning is raised. A class's imbalance ratio is its number of samples divided by the sum of its samples and the majority class's samples. For example, in a multiclass case with [900, 900, 100] samples in classes 0, 1, and 2, respectively, class 2 has a ratio of 0.10 (100 / (900 + 100)). Defaults to 0.10.
min_samples (int) – The minimum number of samples per accepted class. If a minority class falls below both the threshold and min_samples, it is considered severely imbalanced. Must be greater than 0. Defaults to 100.
num_cv_folds (int) – The number of cross-validation folds. Must be non-negative; set to 0 to disable the check that each target value has at least 2 * num_cv_folds instances. Defaults to 3.
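The ratio that threshold is compared against, and the way min_samples combines with it, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not evalml's code: the helpers imbalance_ratios and severely_imbalanced are hypothetical, and the strict "<" comparison at the threshold boundary is an assumption.

```python
from collections import Counter


def imbalance_ratios(labels):
    """Hypothetical helper: count / (count + majority-class count) per class."""
    counts = Counter(labels)
    majority = max(counts.values())
    return {label: n / (n + majority) for label, n in counts.items()}


def severely_imbalanced(labels, threshold=0.10, min_samples=100):
    """Hypothetical helper: classes below both the threshold and min_samples.

    Assumes a strict '<' comparison for both conditions.
    """
    counts = Counter(labels)
    ratios = imbalance_ratios(labels)
    return [c for c, n in counts.items() if ratios[c] < threshold and n < min_samples]


# The [900, 900, 100] example from the threshold description
labels = [0] * 900 + [1] * 900 + [2] * 100
ratios = imbalance_ratios(labels)
print(ratios[2])  # 0.1  -> 100 / (900 + 100)
print(ratios[0])  # 0.5  -> 900 / (900 + 900)

# With only 50 samples, class 2 drops to 50 / 950 (about 0.053) and has < 100 samples
print(severely_imbalanced([0] * 900 + [1] * 900 + [2] * 50))  # [2]
```

Under the strict-comparison assumption, class 2 in the original [900, 900, 100] example sits exactly at 0.10 and would not be flagged with the default threshold.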
Methods
name | Returns a name describing the data check.
validate | Checks if any target labels are imbalanced beyond a threshold for binary and multiclass problems.
name(cls)
Returns a name describing the data check.
validate(self, X, y)
Checks if any target labels are imbalanced beyond a threshold for binary and multiclass problems. Ignores NaN values in target labels if they appear.
Parameters
X (pd.DataFrame, np.ndarray) – Features. Ignored.
y (pd.Series, np.ndarray) – Target labels to check for imbalanced data.
Returns
Dictionary with DataCheckWarnings for classes whose imbalance ratio falls below the threshold (including severe imbalance), and DataCheckErrors for target values with fewer than 2 * num_cv_folds instances.
- Return type
dict
Example
>>> import pandas as pd
>>> from evalml.data_checks import ClassImbalanceDataCheck
>>> X = pd.DataFrame()
>>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> target_check = ClassImbalanceDataCheck(threshold=0.10)
>>> assert target_check.validate(X, y) == {
...     "errors": [{"message": "The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]", "data_check_name": "ClassImbalanceDataCheck", "level": "error", "code": "CLASS_IMBALANCE_BELOW_FOLDS", "details": {"target_values": [0]}}],
...     "warnings": [{"message": "The following labels fall below 10% of the target: [0]", "data_check_name": "ClassImbalanceDataCheck", "level": "warning", "code": "CLASS_IMBALANCE_BELOW_THRESHOLD", "details": {"target_values": [0]}},
...                  {"message": "The following labels in the target have severe class imbalance because they fall under 10% of the target and have less than 100 samples: [0]", "data_check_name": "ClassImbalanceDataCheck", "level": "warning", "code": "CLASS_IMBALANCE_SEVERE", "details": {"target_values": [0]}}],
...     "actions": [],
... }
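The three checks that validate performs can be summarized with a stdlib-only sketch. It is an illustrative approximation, not evalml's implementation: check_class_imbalance is a hypothetical helper, it returns (code, target_values) pairs instead of full DataCheckError/DataCheckWarning dictionaries, and it assumes strict "<" comparisons. The codes reuse those shown in the example output.

```python
import math
from collections import Counter


def check_class_imbalance(y, threshold=0.10, min_samples=100, num_cv_folds=3):
    """Hypothetical sketch of the class-imbalance checks (not evalml's code).

    Returns {"errors": [...], "warnings": [...]} where each entry is a
    (code, target_values) pair; labels keep the order they first appear in y.
    """
    # Ignore missing labels (None or float NaN) before counting.
    labels = [
        v for v in y
        if v is not None and not (isinstance(v, float) and math.isnan(v))
    ]
    counts = Counter(labels)
    results = {"errors": [], "warnings": []}
    if not counts:
        return results

    # Error: a target value has fewer than 2 * num_cv_folds instances.
    # With num_cv_folds=0 the minimum is 0, so this never fires.
    min_instances = 2 * num_cv_folds
    too_few = [c for c, n in counts.items() if n < min_instances]
    if too_few:
        results["errors"].append(("CLASS_IMBALANCE_BELOW_FOLDS", too_few))

    # Warning: n / (n + majority count) falls below the threshold.
    majority = max(counts.values())
    below = [c for c, n in counts.items() if n / (n + majority) < threshold]
    if below:
        results["warnings"].append(("CLASS_IMBALANCE_BELOW_THRESHOLD", below))
        # Severe: below the threshold and also below min_samples.
        severe = [c for c in below if counts[c] < min_samples]
        if severe:
            results["warnings"].append(("CLASS_IMBALANCE_SEVERE", severe))
    return results


print(check_class_imbalance([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))
```

For the target used in the doctest, label 0 has 1 instance, which is below 2 * 3 = 6 and below min_samples, and its ratio is 1 / (1 + 10), about 0.091, which is under 0.10. The sketch therefore reports label 0 under the same three codes as the example output.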