Source code for evalml.model_family.model_family

"""Enum for family of machine learning models."""
from enum import Enum


class ModelFamily(Enum):
    """Enum for family of machine learning models.

    Each member's value is the lowercase, underscore-separated identifier of
    the family; ``str(member)`` yields a human-readable display name and
    ``repr(member)`` yields ``"ModelFamily.<NAME>"``.
    """

    K_NEIGHBORS = "k_neighbors"
    """K Nearest Neighbors model family."""

    RANDOM_FOREST = "random_forest"
    """Random Forest model family."""

    SVM = "svm"
    """SVM model family."""

    XGBOOST = "xgboost"
    """XGBoost model family."""

    LIGHTGBM = "lightgbm"
    """LightGBM model family."""

    LINEAR_MODEL = "linear_model"
    """Linear model family."""

    CATBOOST = "catboost"
    """CatBoost model family."""

    EXTRA_TREES = "extra_trees"
    """Extra Trees model family."""

    ENSEMBLE = "ensemble"
    """Ensemble model family."""

    DECISION_TREE = "decision_tree"
    """Decision Tree model family."""

    EXPONENTIAL_SMOOTHING = "exponential_smoothing"
    """Exponential Smoothing model family."""

    ARIMA = "arima"
    """ARIMA model family."""

    VARMAX = "varmax"
    """VARMAX model family."""

    BASELINE = "baseline"
    """Baseline model family."""

    PROPHET = "prophet"
    """Prophet model family."""

    VOWPAL_WABBIT = "vowpal_wabbit"
    """Vowpal Wabbit model family."""

    NONE = "none"
    """None"""

    def __str__(self):
        """Return the human-readable display name of this model family.

        Returns:
            str: The display name, e.g. ``"Random Forest"`` for
            ``ModelFamily.RANDOM_FOREST``.
        """
        # The mapping is built inside the method on purpose: a plain dict
        # assigned in an Enum class body would itself become an enum member.
        model_family_dict = {
            ModelFamily.K_NEIGHBORS.name: "K Nearest Neighbors",
            ModelFamily.RANDOM_FOREST.name: "Random Forest",
            ModelFamily.SVM.name: "SVM",
            ModelFamily.XGBOOST.name: "XGBoost",
            ModelFamily.LIGHTGBM.name: "LightGBM",
            ModelFamily.LINEAR_MODEL.name: "Linear",
            ModelFamily.CATBOOST.name: "CatBoost",
            ModelFamily.EXTRA_TREES.name: "Extra Trees",
            ModelFamily.DECISION_TREE.name: "Decision Tree",
            ModelFamily.BASELINE.name: "Baseline",
            ModelFamily.ENSEMBLE.name: "Ensemble",
            ModelFamily.EXPONENTIAL_SMOOTHING.name: "Exponential Smoothing",
            ModelFamily.ARIMA.name: "ARIMA",
            ModelFamily.VARMAX.name: "VARMAX",
            ModelFamily.PROPHET.name: "Prophet",
            # Every declared member needs an entry here, otherwise str()
            # on that member raises KeyError.
            ModelFamily.VOWPAL_WABBIT.name: "Vowpal Wabbit",
            ModelFamily.NONE.name: "None",
        }
        return model_family_dict[self.name]

    def __repr__(self):
        """Return the qualified member name, e.g. ``"ModelFamily.SVM"``.

        Returns:
            str: ``"ModelFamily."`` followed by the member's name.
        """
        return "ModelFamily." + self.name

    def is_tree_estimator(self):
        """Check whether this model family uses trees.

        Returns:
            bool: True if this member is one of CATBOOST, EXTRA_TREES,
            RANDOM_FOREST, DECISION_TREE, XGBOOST or LIGHTGBM; False
            otherwise.
        """
        # Members are looked up on the class rather than through ``self``:
        # class-level access is the documented way to reach enum members.
        tree_estimators = {
            ModelFamily.CATBOOST,
            ModelFamily.EXTRA_TREES,
            ModelFamily.RANDOM_FOREST,
            ModelFamily.DECISION_TREE,
            ModelFamily.XGBOOST,
            ModelFamily.LIGHTGBM,
        }
        return self in tree_estimators