# Source code for evalml.pipelines.components.component_base

"""Base class for all components."""
import copy
from abc import ABC, abstractmethod

import cloudpickle

from evalml.exceptions import MethodPropertyNotFoundError
from evalml.pipelines.components.component_base_meta import ComponentBaseMeta
from evalml.utils import (
    classproperty,
    infer_feature_types,
    log_subtitle,
    safe_repr,
)
from evalml.utils.logger import get_logger


class ComponentBase(ABC, metaclass=ComponentBaseMeta):
    """Base class for all components.

    Args:
        parameters (dict): Dictionary of parameters for the component. Defaults to None.
        component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
        random_seed (int): Seed for the random number generator. Defaults to 0.
    """

    # Lazily-populated cache for ``default_parameters``. It is read from each
    # class's own namespace only (see ``default_parameters``), so every
    # subclass gets its own cache entry.
    _default_parameters = None

    def __init__(self, parameters=None, component_obj=None, random_seed=0, **kwargs):
        """Base class for all components.

        Args:
            parameters (dict): Dictionary of parameters for the component. Defaults to None.
            component_obj (obj): Third-party objects useful in component implementation. Defaults to None.
            random_seed (int): Seed for the random number generator. Defaults to 0.
            **kwargs: Additional keyword arguments; ignored by this base class.
        """
        self.random_seed = random_seed
        self._component_obj = component_obj
        # A falsy ``parameters`` (None or an empty mapping) becomes a new empty dict.
        self._parameters = parameters or {}
        self._is_fitted = False

    @property
    @classmethod
    @abstractmethod
    def name(cls):
        """Returns string name of this component."""

    @property
    @classmethod
    @abstractmethod
    def modifies_features(cls):
        """Returns whether this component modifies (subsets or transforms) the features variable during transform.

        For Estimator objects, this attribute determines if the return
        value from `predict` or `predict_proba` should be used as
        features or targets.
        """

    @property
    @classmethod
    @abstractmethod
    def modifies_target(cls):
        """Returns whether this component modifies (subsets or transforms) the target variable during transform.

        For Estimator objects, this attribute determines if the return
        value from `predict` or `predict_proba` should be used as
        features or targets.
        """

    @property
    @classmethod
    @abstractmethod
    def training_only(cls):
        """Returns whether or not this component should be evaluated during training-time only, or during both training and prediction time."""

    @classproperty
    def needs_fitting(cls):
        """Returns boolean determining if component needs fitting before calling predict, predict_proba, transform, or feature_importances.

        This can be overridden to False for components that do not need to be fit
        or whose fit methods do nothing.

        Returns:
            True.
        """
        return True

    @property
    def parameters(self):
        """Returns the parameters which were used to initialize the component.

        Returns:
            dict: A shallow copy, so callers cannot mutate the component's own parameters dict.
        """
        return copy.copy(self._parameters)

    @classproperty
    def default_parameters(cls):
        """Returns the default parameters for this component.

        Our convention is that Component.default_parameters == Component().parameters.
        The value is computed once per class by instantiating it with no arguments.

        Returns:
            dict: Default parameters for this component (a shallow copy of the cached value).
        """
        # Look only in this class's own namespace. A plain ``cls._default_parameters``
        # lookup would find a value cached on a parent class and hand a subclass
        # (whose constructor may have different defaults) the parent's parameters.
        cached = cls.__dict__.get("_default_parameters")
        if cached is None:
            cached = cls().parameters
            cls._default_parameters = cached
        # Copy so callers cannot corrupt the shared per-class cache.
        return copy.copy(cached)

    @classproperty
    def _supported_by_list_API(cls):
        # Components that modify the target are reported as unsupported.
        return not cls.modifies_target

    def clone(self):
        """Constructs a new component with the same parameters and random state.

        Returns:
            A new instance of this component with identical parameters and random state.
        """
        return self.__class__(**self.parameters, random_seed=self.random_seed)

    def fit(self, X, y=None):
        """Fits component to data.

        Args:
            X (pd.DataFrame): The input training data of shape [n_samples, n_features]
            y (pd.Series, optional): The target training data of length [n_samples]

        Returns:
            self

        Raises:
            MethodPropertyNotFoundError: If component does not have a fit method or a component_obj that implements fit.
        """
        X = infer_feature_types(X)
        if y is not None:
            y = infer_feature_types(y)
        # Guard only the attribute lookup: an AttributeError raised *inside* the
        # wrapped object's fit is a real error and must not be reported as a
        # missing fit method.
        try:
            fit_method = self._component_obj.fit
        except AttributeError as err:
            raise MethodPropertyNotFoundError(
                "Component requires a fit method or a component_obj that implements fit"
            ) from err
        fit_method(X, y)
        return self

    def describe(self, print_name=False, return_dict=False):
        """Describe a component and its parameters.

        Args:
            print_name(bool, optional): whether to print name of component
            return_dict(bool, optional): whether to return description as dictionary in the format {"name": name, "parameters": parameters}

        Returns:
            None or dict: Returns dictionary if return_dict is True, else None.
        """
        logger = get_logger(f"{__name__}.describe")
        if print_name:
            log_subtitle(logger, self.name)
        # Take one snapshot; ``self.parameters`` builds a new copy on every access.
        parameters = self.parameters
        for parameter, value in parameters.items():
            logger.info("\t * {} : {}".format(parameter, value))
        if return_dict:
            return {"name": self.name, "parameters": parameters}
        return None

    def save(self, file_path, pickle_protocol=cloudpickle.DEFAULT_PROTOCOL):
        """Saves component at file path.

        Args:
            file_path (str): Location to save file.
            pickle_protocol (int): The pickle data stream format.
        """
        with open(file_path, "wb") as f:
            cloudpickle.dump(self, f, protocol=pickle_protocol)

    @staticmethod
    def load(file_path):
        """Loads component at file path.

        Warning:
            The file is unpickled with cloudpickle, which can execute arbitrary
            code during loading. Only load files from trusted sources.

        Args:
            file_path (str): Location to load file.

        Returns:
            ComponentBase object
        """
        with open(file_path, "rb") as f:
            return cloudpickle.load(f)

    # NOTE: because ``__eq__`` is defined without ``__hash__``, Python sets
    # ``__hash__`` to None for this class, so instances are unhashable.
    def __eq__(self, other):
        """Check for equality.

        Two components are equal when they are instances of this component's
        class and share the same random seed, parameters, and fitted state.
        """
        if not isinstance(other, self.__class__):
            return False
        random_seed_eq = self.random_seed == other.random_seed
        if not random_seed_eq:
            return False
        for attribute in ("_parameters", "_is_fitted"):
            if getattr(self, attribute) != getattr(other, attribute):
                return False
        return True

    def __str__(self):
        """String representation of a component."""
        return self.name

    def __repr__(self):
        """String representation of a component."""
        parameters_repr = ", ".join(
            f"{key}={safe_repr(value)}" for key, value in self.parameters.items()
        )
        return f"{type(self).__name__}({parameters_repr})"