from sklearn.tree import DecisionTreeClassifier as SKDecisionTreeClassifier
from skopt.space import Integer

from evalml.model_family import ModelFamily
from evalml.pipelines.components.estimators import Estimator
from evalml.problem_types import ProblemTypes
from evalml.utils import deprecate_arg


class DecisionTreeClassifier(Estimator):
    """Decision Tree Classifier.

    Wraps scikit-learn's ``DecisionTreeClassifier`` as an ``Estimator``
    component. The constructor builds the underlying scikit-learn model and
    hands it, together with the parameters used, to the ``Estimator`` base
    class.
    """

    name = "Decision Tree Classifier"

    # Candidate values / ranges for the tunable hyperparameters.
    # NOTE(review): scikit-learn deprecated max_features="auto" in 1.1 and
    # removed it in 1.3 -- confirm against the scikit-learn version this
    # project pins before relying on that option.
    hyperparameter_ranges = {
        "criterion": ["gini", "entropy"],
        "max_features": ["auto", "sqrt", "log2"],
        "max_depth": Integer(4, 10),
    }
    model_family = ModelFamily.DECISION_TREE
    supported_problem_types = [
        ProblemTypes.BINARY,
        ProblemTypes.MULTICLASS,
        ProblemTypes.TIME_SERIES_BINARY,
        ProblemTypes.TIME_SERIES_MULTICLASS,
    ]

    def __init__(self,
                 criterion="gini",
                 max_features="auto",
                 max_depth=6,
                 min_samples_split=2,
                 min_weight_fraction_leaf=0.0,
                 random_state=None,
                 random_seed=0,
                 **kwargs):
        """Build the wrapped scikit-learn decision tree classifier.

        All tree parameters below are forwarded unchanged to scikit-learn's
        ``DecisionTreeClassifier``; see its documentation for full semantics.

        Args:
            criterion: Function used to measure the quality of a split.
            max_features: Number of features considered when looking for
                the best split.
            max_depth: Maximum depth of the tree.
            min_samples_split: Minimum number of samples required to split
                an internal node.
            min_weight_fraction_leaf: Minimum weighted fraction of the total
                sample weights required at a leaf node.
            random_state: Deprecated alias of ``random_seed``; resolved via
                ``deprecate_arg``.
            random_seed: Seed passed to the scikit-learn model as its
                ``random_state`` and to the ``Estimator`` base class.
            **kwargs: Extra keyword arguments merged into the parameters and
                forwarded to the scikit-learn model. A key that matches one
                of the named parameters overrides it.
        """
        parameters = {
            "criterion": criterion,
            "max_features": max_features,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_weight_fraction_leaf": min_weight_fraction_leaf,
        }
        parameters.update(kwargs)

        # Presumably returns random_state (with a deprecation warning) when
        # it was supplied and random_seed otherwise -- confirm against
        # evalml.utils.deprecate_arg.
        random_seed = deprecate_arg("random_state", "random_seed",
                                    random_state, random_seed)

        # The seed is passed separately so it is not recorded in `parameters`.
        dt_classifier = SKDecisionTreeClassifier(random_state=random_seed,
                                                 **parameters)
        super().__init__(parameters=parameters,
                         component_obj=dt_classifier,
                         random_seed=random_seed)