"""General utility methods."""
import importlib
import logging
import os
import warnings
from collections import namedtuple
from functools import reduce
import numpy as np
import pandas as pd
from sklearn.utils import check_random_state
from evalml.exceptions import MissingComponentError
# Module-level logger named after this module; get_importable_subclasses emits debug messages through it.
logger = logging.getLogger(__name__)
def import_or_raise(library, error_msg=None, warning=False):
    """Attempts to import the requested library by name. If the import fails, raises an ImportError or warning.

    Args:
        library (str): The name of the library.
        error_msg (str): Error message appended to the default message if the import fails.
        warning (bool): If True, import_or_raise gives a warning instead of ImportError. Defaults to False.

    Returns:
        The imported module if importing succeeded. If importing fails and ``warning`` is True,
        a warning is emitted and None is returned.

    Raises:
        ImportError: If attempting to import the library fails because the library is not installed.
        Exception: If importing the library fails for any other reason.
    """
    try:
        return importlib.import_module(library)
    except ImportError as err:
        if error_msg is None:
            error_msg = ""
        msg = f"Missing optional dependency '{library}'. Please use pip to install {library}. {error_msg}"
        if warning:
            warnings.warn(msg)
        else:
            # Chain the original error so the underlying cause stays visible in tracebacks.
            raise ImportError(msg) from err
    except Exception as ex:
        msg = f"An exception occurred while trying to import `{library}`: {str(ex)}"
        if warning:
            warnings.warn(msg)
        else:
            raise Exception(msg) from ex
    # Only reached when the import failed and warning=True.
    return None
def convert_to_seconds(input_str):
    """Converts a string describing a length of time to its length in seconds.

    The string must contain a number and a unit separated by whitespace. Units are matched
    case-insensitively and may be hours ("h", "hr", "hrs", "hour", "hours"), minutes
    ("m", "min", "mins", "minute", "minutes") or seconds ("s", "sec", "secs", "second", "seconds").

    Args:
        input_str (str): The string to be parsed and converted to seconds.

    Returns:
        float: The length of time described by ``input_str``, in seconds.

    Raises:
        AssertionError: If an invalid unit is used.

    Examples:
        >>> assert convert_to_seconds("10 hr") == 36000.0
        >>> assert convert_to_seconds("30 minutes") == 1800.0
        >>> assert convert_to_seconds("2.5 min") == 150.0
    """
    hours = {"h", "hr", "hrs", "hour", "hours"}
    minutes = {"m", "min", "mins", "minute", "minutes"}
    seconds = {"s", "sec", "secs", "second", "seconds"}
    value, unit = input_str.split()
    # Match the unit exactly instead of stripping any trailing "s"; stripping made
    # e.g. "ms" parse as minutes.
    normalized_unit = unit.lower()
    if normalized_unit in seconds:
        return float(value)
    elif normalized_unit in minutes:
        return float(value) * 60
    elif normalized_unit in hours:
        return float(value) * 3600
    else:
        msg = (
            "Invalid unit. Units must be hours, mins, or seconds. Received '{}'".format(
                unit
            )
        )
        raise AssertionError(msg)
# Inclusive lower and upper limits for a seed handed to np.random.RandomState.
# Both limits fit in numpy.int32, which keeps seeds valid on 32-bit platforms
# (see the numpy.random.RandomState documentation for the accepted seed range).
SEED_BOUNDS = namedtuple("SEED_BOUNDS", ["min_bound", "max_bound"])(
    min_bound=0,
    max_bound=np.iinfo(np.int32).max,
)
def get_random_state(seed):
    """Build a numpy.random.RandomState instance from ``seed``.

    Args:
        seed (None, int, np.random.RandomState object): seed to use to generate numpy.random.RandomState. Must be between SEED_BOUNDS.min_bound and SEED_BOUNDS.max_bound, inclusive.

    Raises:
        ValueError: If the input seed is not within the acceptable range.

    Returns:
        A numpy.random.RandomState instance.
    """
    low, high = SEED_BOUNDS.min_bound, SEED_BOUNDS.max_bound
    is_integer_seed = isinstance(seed, (int, np.integer))
    # Only integer seeds are range-checked; other inputs go straight to check_random_state.
    if is_integer_seed and not low <= seed <= high:
        raise ValueError(
            'Seed "{}" is not in the range [{}, {}], inclusive'.format(seed, low, high)
        )
    return check_random_state(seed)
def get_random_seed(
    random_state, min_bound=SEED_BOUNDS.min_bound, max_bound=SEED_BOUNDS.max_bound
):
    """Given a numpy.random.RandomState object, generate an int representing a seed value for another random number generator. Or, if given an int, return that int.

    To protect against invalid input to a particular library's random number generator, if an int value is provided, and it is outside the bounds "[min_bound, max_bound)", the value will be projected into the range between the min_bound (inclusive) and max_bound (exclusive) using modular arithmetic.

    Args:
        random_state (int, numpy.random.RandomState, numpy.random.Generator): random state
        min_bound (int): Minimum value of the generated seed (inclusive). Must be less than max_bound.
            Defaults to SEED_BOUNDS.min_bound.
        max_bound (int): Maximum value of the generated seed (exclusive). Must be greater than min_bound.
            Defaults to SEED_BOUNDS.max_bound.

    Returns:
        int: Seed for random number generator

    Raises:
        ValueError: If boundaries are not valid.
    """
    if not min_bound < max_bound:
        raise ValueError(
            "Provided min_bound {} is not less than max_bound {}".format(
                min_bound, max_bound
            )
        )
    if isinstance(random_state, np.random.RandomState):
        return random_state.randint(min_bound, max_bound)
    if isinstance(random_state, np.random.Generator):
        # Generator.integers excludes the high end by default, matching randint above.
        return int(random_state.integers(min_bound, max_bound))
    if random_state < min_bound or random_state >= max_bound:
        # Project an out-of-range int seed into [min_bound, max_bound).
        return ((random_state - min_bound) % (max_bound - min_bound)) + min_bound
    return random_state
class classproperty:
    """Decorator exposing a function as a read-only property on the class itself.

    The wrapped function receives the owning class, whether the attribute is read
    from the class or from one of its instances.

    Example:
    .. code-block::

        class LogisticRegressionBinaryPipeline(PipelineBase):
            component_graph = ['Simple Imputer', 'Logistic Regression Classifier']

            @classproperty
            def summary(cls):
                summary = ""
                for component in cls.component_graph:
                    component = handle_component_class(component)
                    summary += component.name + " + "
                return summary

        assert LogisticRegressionBinaryPipeline.summary == "Simple Imputer + Logistic Regression Classifier + "
        assert LogisticRegressionBinaryPipeline().summary == "Simple Imputer + Logistic Regression Classifier + "
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        """Get property value by calling the wrapped function with the owning class."""
        getter = self.func
        return getter(owner)
def _get_subclasses(base_class):
    """Gets all of the leaf nodes in the hierarchy tree for a given base class.

    Each leaf class appears once in the result, even when it is reachable through
    several parents (multiple inheritance).

    Args:
        base_class (abc.ABCMeta): Class to find all of the children for.

    Returns:
        subclasses (list): List of all children that are not base classes.
    """
    classes_to_check = base_class.__subclasses__()
    visited = set()
    subclasses = []
    while classes_to_check:
        subclass = classes_to_check.pop()
        # A class with several parents inside the tree is reachable more than once;
        # process it only the first time to avoid duplicate results.
        if subclass in visited:
            continue
        visited.add(subclass)
        children = subclass.__subclasses__()
        if children:
            classes_to_check.extend(children)
        else:
            subclasses.append(subclass)
    return subclasses
# Names of estimator classes that get_importable_subclasses drops when called with
# used_in_automl=True.
_not_used_in_automl = {
    # Baseline estimators
    "BaselineClassifier",
    "BaselineRegressor",
    "TimeSeriesBaselineEstimator",
    # Stacked ensembles
    "StackedEnsembleClassifier",
    "StackedEnsembleRegressor",
    # Other estimators
    "KNeighborsClassifier",
    "LinearRegressor",
    "SVMClassifier",
    "SVMRegressor",
    # Vowpal Wabbit estimators
    "VowpalWabbitBinaryClassifier",
    "VowpalWabbitMulticlassClassifier",
    "VowpalWabbitRegressor",
}
def get_importable_subclasses(base_class, used_in_automl=True):
    """Get importable subclasses of a base class. Used to list all of our estimators, transformers, components and pipelines dynamically.

    A leaf subclass is considered importable when it is defined in an ``evalml.pipelines``
    module and can be instantiated with no arguments without raising ImportError,
    MissingComponentError or TypeError.

    Args:
        base_class (abc.ABCMeta): Base class to find all of the subclasses for.
        used_in_automl (bool): Not all components/pipelines/estimators are used in automl search. If True,
            only include those subclasses that are used in the search. This excludes the classes whose
            names are listed in ``_not_used_in_automl`` (baseline, stacked ensemble, KNN, SVM, linear
            regressor and Vowpal Wabbit estimators). Defaults to True.

    Returns:
        List of subclasses.
    """
    classes = []
    for cls in _get_subclasses(base_class):
        if "evalml.pipelines" not in cls.__module__:
            continue
        # Filter excluded classes before instantiating them instead of constructing
        # them only to discard them afterwards.
        if used_in_automl and cls.__name__ in _not_used_in_automl:
            continue
        try:
            # Instantiation is the importability probe: missing optional dependencies surface here.
            cls()
        except (ImportError, MissingComponentError, TypeError):
            logger.debug(
                f"Could not import class {cls.__name__} in get_importable_subclasses"
            )
            continue
        classes.append(cls)
    return classes
def _rename_column_names_to_numeric(X, flatten_tuples=True):
    """Rename the columns of ``X`` to their integer positions.

    Used in LightGBM and XGBoost estimator classes to rename column names when the input is a pd.DataFrame in case it has column names that contain symbols ([, ], <) that these estimators cannot natively handle.

    Args:
        X (pd.DataFrame): The input training data of shape [n_samples, n_features]
        flatten_tuples (bool): Whether to flatten MultiIndex or tuple column names. LightGBM cannot handle columns with tuple names.

    Returns:
        Transformed X where column names are renamed to numerical values. numpy arrays and lists
        are wrapped in a new pd.DataFrame instead; the input DataFrame itself is not modified.
    """
    if isinstance(X, (np.ndarray, list)):
        return pd.DataFrame(X)

    original_names = list(X.columns)
    result = X.copy()
    has_multiindex = len(X.columns) > 0 and isinstance(X.columns, pd.MultiIndex)
    if flatten_tuples and has_multiindex:
        # Replace each tuple label by its string form, then key the mapping by those strings.
        result.columns = list(map(str, result.columns))
        lookup_keys = [str(name) for name in original_names]
    else:
        lookup_keys = original_names
    position_by_name = {key: position for position, key in enumerate(lookup_keys)}
    result.rename(columns=position_by_name, inplace=True)
    return result
def jupyter_check():
    """Get whether or not the code is being run in a Ipython environment (such as Jupyter Notebook or Jupyter Lab).

    Returns:
        boolean: True if Ipython, False otherwise.
    """
    try:
        ipy = import_or_raise("IPython")
        # get_ipython() gives the active shell object (or None outside IPython);
        # coerce it so the return value matches the documented boolean type.
        return bool(ipy.core.getipython.get_ipython())
    except Exception:
        # Best effort: if IPython is missing or unusable, we are not in IPython.
        return False
def safe_repr(value):
    """Return a repr of ``value`` that can be evaluated back as Python source.

    Float NaN becomes ``np.nan`` and infinite floats become ``float('inf')`` /
    ``float('-inf')``; every other value uses the built-in ``repr``.

    Args:
        value: The item to convert

    Returns:
        String representation of the value
    """
    if not isinstance(value, float):
        return repr(value)
    if pd.isna(value):
        return "np.nan"
    if np.isinf(value):
        return f"float('{repr(value)}')"
    return repr(value)
def is_all_numeric(df):
    """Report whether a DataFrame holds only numeric, non-missing values.

    The numeric check reads the woodwork semantic tags (``df.ww.semantic_tags``), so ``df``
    must have a woodwork schema initialized.

    Args:
        df (pd.DataFrame): The DataFrame to check data types of.

    Returns:
        True if all the columns are numeric and are not missing any values, False otherwise.
    """
    every_column_numeric = all(
        "numeric" in tags for tags in df.ww.semantic_tags.values()
    )
    if not every_column_numeric:
        return False
    return not df.isnull().any().any()
def pad_with_nans(pd_data, num_to_pad):
    """Prepend ``num_to_pad`` rows of NaN to a Series or DataFrame.

    The result has a fresh 0-based index.

    Args:
        pd_data (pd.DataFrame or pd.Series): Data to pad.
        num_to_pad (int): Number of nans to pad.

    Returns:
        pd.DataFrame or pd.Series
    """
    nan_values = [np.nan] * num_to_pad
    if isinstance(pd_data, pd.Series):
        padding = pd.Series(nan_values, name=pd_data.name)
    else:
        padding = pd.DataFrame({column: nan_values for column in pd_data.columns})
    combined = pd.concat([padding, pd_data], ignore_index=True)
    # Concatenating mixed numeric and object data can leave everything as object dtype;
    # inferring object dtypes keeps numeric columns numeric in the padded result.
    return combined.convert_dtypes(
        infer_objects=True,
        convert_string=False,
        convert_integer=False,
        convert_boolean=False,
        convert_floating=False,
    )
def _get_rows_without_nans(*data):
    """Compute a boolean array marking where all entries in the data are non-nan.

    Args:
        *data (sequence of pd.Series or pd.DataFrame): Data to check. ``None`` or empty
            entries are treated as having no NaNs.

    Returns:
        np.ndarray: mask where each entry is True if and only if all corresponding entries in that index in data
            are non-nan. With no data (or only None/empty data) the result is ``np.array([True])``.
    """

    def _not_nan(pd_data):
        # None/empty inputs contribute a single True, which broadcasts against the other masks.
        if pd_data is None or len(pd_data) == 0:
            return np.array([True])
        if isinstance(pd_data, pd.Series):
            return ~pd_data.isna().values
        elif isinstance(pd_data, pd.DataFrame):
            return ~pd_data.isna().any(axis=1).values
        else:
            return pd_data

    # Turn every input into a mask first and then AND them together. Reducing over the raw
    # inputs returned a lone input unchanged instead of its mask.
    return reduce(np.logical_and, map(_not_nan, data), np.array([True]))
def drop_rows_with_nans(*pd_data):
    """Remove every row index at which any of the given inputs has a NaN.

    ``None`` and empty inputs are passed through unchanged.

    Args:
        *pd_data: sequence of pd.Series or pd.DataFrame or None

    Returns:
        list of pd.DataFrame or pd.Series or None
    """
    keep_mask = _get_rows_without_nans(*pd_data)
    return [
        data if data is None or data.empty else data.iloc[keep_mask]
        for data in pd_data
    ]
def _file_path_check(filepath=None, format="png", interactive=False, is_plotly=False):
    """Helper function to check the filepath being passed.

    The returned path keeps the directory and base name of ``filepath``. Its extension is
    "html" for interactive plotly figures, otherwise the (lower-cased) extension already on
    ``filepath``, otherwise ``format``. The file is created if it does not exist, but an
    existing file is not truncated.

    Args:
        filepath (str or Path, optional): Location to save file.
        format (str): Extension for figure to be saved as. Defaults to 'png'.
        interactive (bool, optional): If True and fig is of type plotly.Figure, sets the format to 'html'.
        is_plotly (bool, optional): Check to see if the fig being passed is of type plotly.Figure.

    Returns:
        String representing the final filepath the image will be saved to, or ``filepath``
        unchanged if it is falsy.

    Raises:
        ValueError: If the final filepath cannot be opened for writing.
    """
    if not filepath:
        return filepath
    path_and_name, extension = os.path.splitext(str(filepath))
    extension = extension[1:].lower()
    if is_plotly and interactive:
        format_ = "html"
    elif extension:
        format_ = extension
    else:
        # No usable extension on the path: fall back to the requested format
        # (covers interactive non-plotly figures too, which otherwise got ".None").
        format_ = format
    filepath = f"{path_and_name}.{format_}"
    try:
        # Append mode verifies the location is writeable (creating the file if needed)
        # without destroying the contents of a file that already exists there.
        with open(filepath, "a"):
            pass
    except OSError as err:
        raise ValueError(
            "Specified filepath is not writeable: {}".format(filepath)
        ) from err
    return filepath
def save_plot(
    fig, filepath=None, format="png", interactive=False, return_filepath=False
):
    """Saves fig to filepath if specified, or to a default location if not.

    Recognized figure types are plotly Figures, graphviz Sources, matplotlib Figures and
    matplotlib Axes (e.g. as returned by seaborn). Graphviz sources are always rendered as
    PNG. A ``fig`` of any other type is not written anywhere.

    Args:
        fig (Figure): Figure to be saved.
        filepath (str or Path, optional): Location to save file. Default is with filename "test_plot".
        format (str): Extension for figure to be saved as. Ignored if interactive is True and fig
            is of type plotly.Figure. Defaults to 'png'.
        interactive (bool, optional): If True and fig is of type plotly.Figure, saves the fig as interactive
            instead of static, and format will be set to 'html'. Defaults to False.
        return_filepath (bool, optional): Whether to return the final filepath the image is saved to. Defaults to False.

    Returns:
        String representing the final filepath the image was saved to if return_filepath is set to True.
        Defaults to None.
    """
    import_or_raise("plotly", error_msg="Cannot find dependency plotly")
    graphviz_ = import_or_raise(
        "graphviz", error_msg="Please install graphviz to visualize trees."
    )
    import_or_raise("matplotlib", error_msg="Cannot find dependency matplotlib")
    # Importing a package does not necessarily import its submodules, so load the ones
    # used below explicitly rather than relying on another module having imported them.
    plotly_go = importlib.import_module("plotly.graph_objects")
    plt_ = importlib.import_module("matplotlib.pyplot")
    axes_ = importlib.import_module("matplotlib.axes")

    is_plotly = False
    is_graphviz = False
    is_plt = False
    is_seaborn = False
    format = format if format else "png"
    if isinstance(fig, plotly_go.Figure):
        is_plotly = True
    elif isinstance(fig, graphviz_.Source):
        is_graphviz = True
    elif isinstance(fig, plt_.Figure):
        is_plt = True
    elif isinstance(fig, axes_.Axes):
        # NOTE(review): checks the Axes base class instead of SubplotBase, which recent
        # matplotlib releases fold into Axes; subplot axes are Axes instances either way.
        is_seaborn = True

    if not filepath:
        extension = "html" if interactive and is_plotly else format
        filepath = os.path.join(os.getcwd(), f"test_plot.{extension}")

    filepath = _file_path_check(
        filepath, format=format, interactive=interactive, is_plotly=is_plotly
    )

    if is_plotly and interactive:
        fig.write_html(file=filepath)
    elif is_plotly:
        fig.write_image(file=filepath, engine="kaleido")
    elif is_graphviz:
        # Render as PNG next to the extension-less path; the reported path gets ".png".
        filepath_, _ = os.path.splitext(filepath)
        fig.format = "png"
        filepath = f"{filepath_}.png"
        fig.render(filename=filepath_, view=False, cleanup=True)
    elif is_plt:
        fig.savefig(fname=filepath)
    elif is_seaborn:
        fig.figure.savefig(fname=filepath)

    if return_filepath:
        return filepath
def deprecate_arg(old_arg, new_arg, old_value, new_value):
    """Pick between a deprecated argument and its replacement, warning if the old one was used.

    Args:
        old_arg (str): Name of old/deprecated argument.
        new_arg (str): Name of new argument.
        old_value (Any): Value the user passed in for the old argument.
        new_value (Any): Value the user passed in for the new argument.

    Returns:
        old_value if not None, else new_value
    """
    if old_value is None:
        return new_value
    warnings.warn(
        f"Argument '{old_arg}' has been deprecated in favor of '{new_arg}'. "
        f"Passing '{old_arg}' in future versions will result in an error."
    )
    return old_value
def contains_all_ts_parameters(problem_configuration):
    """Check that a time series problem configuration has every required key.

    Args:
        problem_configuration (dict): Problem configuration.

    Returns:
        bool, str: (True, "") when time_index, gap, max_delay and forecast_horizon are all present
            and time_index is not None; otherwise False and a non-empty error message.
    """
    required_parameters = {"time_index", "gap", "max_delay", "forecast_horizon"}
    is_complete = (
        bool(problem_configuration)
        and all(name in problem_configuration for name in required_parameters)
        and problem_configuration["time_index"] is not None
    )
    if is_complete:
        return True, ""
    msg = (
        "problem_configuration must be a dict containing values for at least the time_index, gap, max_delay, "
        f"and forecast_horizon parameters, and time_index cannot be None. Received {problem_configuration}."
    )
    return False, msg
# Named-tuple type returned by are_ts_parameters_valid_for_split.
_validation_result = namedtuple(
    typename="TSParameterValidationResult",
    field_names=["is_valid", "msg", "smallest_split_size", "max_window_size"],
)
def are_ts_parameters_valid_for_split(
    gap, max_delay, forecast_horizon, n_obs, n_splits
):
    """Check that the time series window fits inside the smallest cross-validation split.

    Args:
        gap (int): gap value.
        max_delay (int): max_delay value.
        forecast_horizon (int): forecast_horizon value.
        n_obs (int): Number of observations in the dataset.
        n_splits (int): Number of cross validation splits.

    Returns:
        TsParameterValidationResult - named tuple with four fields
            is_valid (bool): True if parameters are valid.
            msg (str): Contains error message to display. Empty if is_valid.
            smallest_split_size (int): Smallest split size given n_obs and n_splits.
            max_window_size (int): Max window size given gap, max_delay, forecast_horizon.
    """
    split_size = n_obs // (n_splits + 1)
    window_size = gap + max_delay + forecast_horizon
    if split_size <= window_size:
        message = (
            f"Since the data has {n_obs} observations and n_splits={n_splits}, "
            f"the smallest split would have {split_size} observations. "
            f"Since {gap + max_delay + forecast_horizon} (gap + max_delay + forecast_horizon) >= {split_size}, "
            "then at least one of the splits would be empty by the time it reaches the pipeline. "
            "Please use a smaller number of splits, reduce one or more these parameters, or collect more data."
        )
        return _validation_result(False, message, split_size, window_size)
    return _validation_result(True, "", split_size, window_size)