"""Data check that checks if any of the target labels are imbalanced, or if the number of values for any target is below 2 times the number of CV folds.

Use for classification problems.
"""
from evalml.data_checks import (
DataCheck,
DataCheckError,
DataCheckMessageCode,
DataCheckWarning,
)
from evalml.utils import infer_feature_types
class ClassImbalanceDataCheck(DataCheck):
    """Check if any of the target labels are imbalanced, or if the number of values for any target is below 2 times the number of CV folds. Use for classification problems.

    Args:
        threshold (float): The minimum threshold allowed for class imbalance before a warning is raised.
            This threshold is calculated by comparing the number of samples in each class to the sum of samples in that class and the majority class.
            For example, a multiclass case with [900, 900, 100] samples per classes 0, 1, and 2, respectively,
            would have a 0.10 threshold for class 2 (100 / (900 + 100)). Must be in the range (0, 0.5]. Defaults to 0.10.
        min_samples (int): The minimum number of samples per accepted class. If the minority class is both below the threshold and min_samples,
            then we consider this severely imbalanced. Must be greater than 0. Defaults to 100.
        num_cv_folds (int): The number of cross-validation folds. Must be non-negative. Any class with fewer than
            2 * num_cv_folds samples is reported as an error; choose 0 to disable this check. Defaults to 3.

    Raises:
        ValueError: If threshold is not within (0, 0.5], if min_samples is not greater than 0,
            or if num_cv_folds is less than 0.
    """

    def __init__(self, threshold=0.1, min_samples=100, num_cv_folds=3):
        if threshold <= 0 or threshold > 0.5:
            raise ValueError(
                "Provided threshold {} is not within the range (0, 0.5]".format(
                    threshold
                )
            )
        self.threshold = threshold

        if min_samples <= 0:
            raise ValueError(
                "Provided value min_samples {} is not greater than 0".format(
                    min_samples
                )
            )
        self.min_samples = min_samples

        if num_cv_folds < 0:
            raise ValueError(
                "Provided number of CV folds {} is less than 0".format(num_cv_folds)
            )
        # Despite its name, this stores the minimum number of instances each
        # class must have (2 per fold); the attribute name is kept as-is so
        # existing callers that read `cv_folds` keep working.
        self.cv_folds = num_cv_folds * 2

    def validate(self, X, y):
        """Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.

        Ignores NaN values in target labels if they appear.

        Args:
            X (pd.DataFrame, np.ndarray): Features. Ignored.
            y (pd.Series, np.ndarray): Target labels to check for imbalanced data.

        Returns:
            dict: Dictionary with a DataCheckError listing labels with fewer than 2 * num_cv_folds instances,
                a DataCheckWarning listing labels whose share count / (count + majority count) falls below
                the threshold, and a second DataCheckWarning listing labels that are both below the threshold
                and have fewer than min_samples instances.

        Examples:
            >>> import pandas as pd
            ...
            >>> X = pd.DataFrame()
            >>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.10)
            >>> assert class_imb_dc.validate(X, y) == {
            ...     "errors": [{"message": "The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]",
            ...                 "data_check_name": "ClassImbalanceDataCheck",
            ...                 "level": "error",
            ...                 "code": "CLASS_IMBALANCE_BELOW_FOLDS",
            ...                 "details": {"target_values": [0], "rows": None, "columns": None}}],
            ...     "warnings": [{"message": "The following labels fall below 10% of the target: [0]",
            ...                   "data_check_name": "ClassImbalanceDataCheck",
            ...                   "level": "warning",
            ...                   "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
            ...                   "details": {"target_values": [0], "rows": None, "columns": None}},
            ...                  {"message": "The following labels in the target have severe class imbalance because they fall under 10% of the target and have less than 100 samples: [0]",
            ...                   "data_check_name": "ClassImbalanceDataCheck",
            ...                   "level": "warning",
            ...                   "code": "CLASS_IMBALANCE_SEVERE",
            ...                   "details": {"target_values": [0], "rows": None, "columns": None}}],
            ...     "actions": []}
            ...
            ...
            >>> y = pd.Series([0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.30, min_samples=5, num_cv_folds=1)
            >>> assert class_imb_dc.validate(X, y) == {
            ...     'warnings': [{'message': 'The following labels fall below 30% of the target: [0]',
            ...                   'data_check_name': 'ClassImbalanceDataCheck',
            ...                   'level': 'warning',
            ...                   'code': 'CLASS_IMBALANCE_BELOW_THRESHOLD',
            ...                   'details': {'target_values': [0], "rows": None, "columns": None}},
            ...                  {'message': 'The following labels in the target have severe class imbalance because they fall under 30% of the target and have less than 5 samples: [0]',
            ...                   'data_check_name': 'ClassImbalanceDataCheck',
            ...                   'level': 'warning',
            ...                   'code': 'CLASS_IMBALANCE_SEVERE',
            ...                   'details': {'target_values': [0], "rows": None, "columns": None}}],
            ...     'errors': [],
            ...     'actions': []}
            ...
            ...
            >>> y = pd.Series([0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.30, num_cv_folds=1)
            >>> assert class_imb_dc.validate(X, y) == {'warnings': [], 'errors': [], 'actions': []}
        """
        results = {"warnings": [], "errors": [], "actions": []}

        y = infer_feature_types(y)
        # value_counts drops NaN labels and, with sort=True, orders counts
        # descending, so the first entry is the majority class.
        fold_counts = y.value_counts(normalize=False, sort=True)
        if len(fold_counts) == 0:
            return results

        # Errors first: labels with fewer instances than 2 * num_cv_folds
        # (stored in self.cv_folds). With num_cv_folds=0 nothing can match.
        below_folds = fold_counts.where(fold_counts < self.cv_folds).dropna()
        if len(below_folds):
            below_folds_values = sorted(below_folds.index.tolist())
            error_msg = "The number of instances of these targets is less than 2 * the number of cross folds = {} instances: {}"
            DataCheck._add_message(
                DataCheckError(
                    message=error_msg.format(self.cv_folds, below_folds_values),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_FOLDS,
                    details={"target_values": below_folds_values},
                ),
                results,
            )

        # Share of each label relative to itself plus the majority class;
        # the majority class always scores 0.5, which the (0, 0.5] threshold
        # range guarantees is never flagged.
        majority_count = fold_counts.values[0]
        counts = fold_counts / (fold_counts + majority_count)
        below_threshold = counts.where(counts < self.threshold).dropna()
        if len(below_threshold):
            below_threshold_values = below_threshold.index.tolist()
            warning_msg = "The following labels fall below {:.0f}% of the target: {}"
            DataCheck._add_message(
                DataCheckWarning(
                    message=warning_msg.format(
                        self.threshold * 100, below_threshold_values
                    ),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_THRESHOLD,
                    details={"target_values": below_threshold_values},
                ),
                results,
            )

        # Severe imbalance: labels that are both below the ratio threshold and
        # have fewer than min_samples instances.
        sample_counts = fold_counts.where(fold_counts < self.min_samples).dropna()
        if len(below_threshold) and len(sample_counts):
            severe_imbalance = [
                label
                for label in sample_counts.index.tolist()
                if label in below_threshold.index
            ]
            # Guard so a warning with an empty label list is never emitted.
            if severe_imbalance:
                warning_msg = "The following labels in the target have severe class imbalance because they fall under {:.0f}% of the target and have less than {} samples: {}"
                DataCheck._add_message(
                    DataCheckWarning(
                        message=warning_msg.format(
                            self.threshold * 100, self.min_samples, severe_imbalance
                        ),
                        data_check_name=self.name,
                        message_code=DataCheckMessageCode.CLASS_IMBALANCE_SEVERE,
                        details={"target_values": severe_imbalance},
                    ),
                    results,
                )
        return results