import cloudpickle
from .classification import (
LogisticRegressionPipeline,
RFClassificationPipeline,
XGBoostPipeline
)
from .regression import LinearRegressionPipeline, RFRegressionPipeline
from evalml.model_types import handle_model_types
from evalml.problem_types import handle_problem_types

ALL_PIPELINES = [
    RFClassificationPipeline,
    XGBoostPipeline,
    LogisticRegressionPipeline,
    LinearRegressionPipeline,
    RFRegressionPipeline
]


def get_pipelines(problem_type, model_types=None):
    """Returns the pipelines available for the given problem type, optionally filtered by model type.

    Arguments:
        problem_type (ProblemTypes or str): the problem type the pipelines should work for
        model_types (list[ModelTypes or str]): model types to match. If None, returns all pipelines for the problem type

    Returns:
        pipelines, list of all pipelines that match
    """
    if model_types is not None and not isinstance(model_types, list):
        raise TypeError("model_types parameter is not a list.")

    problem_pipelines = []

    # normalize any string model types into ModelTypes values
    if model_types:
        model_types = [handle_model_types(model_type) for model_type in model_types]

    # collect every pipeline that supports the given problem type
    problem_type = handle_problem_types(problem_type)
    for p in ALL_PIPELINES:
        if problem_type in p.problem_types:
            problem_pipelines.append(p)

    if model_types is None:
        return problem_pipelines

    # validate the requested model types against those available for this problem type
    all_model_types = list_model_types(problem_type)
    for model_type in model_types:
        if model_type not in all_model_types:
            raise RuntimeError("Unrecognized model type for problem type %s: %s" % (problem_type, model_type))

    # keep only the pipelines whose model type was requested
    pipelines = []
    for p in problem_pipelines:
        if p.model_type in model_types:
            pipelines.append(p)

    return pipelines
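

# A minimal usage sketch for get_pipelines. The "binary" problem type string and the
# "random_forest" model type string are assumptions about what handle_problem_types /
# handle_model_types accept; substitute whichever ProblemTypes / ModelTypes values your
# installation exposes.
#
#     all_binary_pipelines = get_pipelines(problem_type="binary")
#     rf_pipelines = get_pipelines(problem_type="binary", model_types=["random_forest"])
#     # rf_pipelines is a list of pipeline classes, e.g. [RFClassificationPipeline]
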

def list_model_types(problem_type):
    """Lists the model types available for a particular problem type.

    Arguments:
        problem_type (ProblemTypes or str): binary, multiclass, or regression

    Returns:
        model_types, list of model types
    """
    # gather the pipelines that support the problem type, then de-duplicate their model types
    problem_pipelines = []
    problem_type = handle_problem_types(problem_type)
    for p in ALL_PIPELINES:
        if problem_type in p.problem_types:
            problem_pipelines.append(p)
    return list(set([p.model_type for p in problem_pipelines]))
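

# A minimal usage sketch for list_model_types; "regression" is assumed to be a valid problem
# type string for handle_problem_types, and the returned values are illustrative.
#
#     regression_model_types = list_model_types("regression")
#     # e.g. [ModelTypes.LINEAR_MODEL, ModelTypes.RANDOM_FOREST]
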

def save_pipeline(pipeline, file_path):
    """Saves a pipeline to the given file path using cloudpickle.

    Args:
        pipeline (PipelineBase) : pipeline to save
        file_path (str) : location to save the file

    Returns:
        None
    """
    with open(file_path, 'wb') as f:
        cloudpickle.dump(pipeline, f)
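

# A minimal usage sketch for save_pipeline; "pipeline.pkl" is just an example path, and
# `pipeline` stands for any already-constructed pipeline object from this module's pipelines.
#
#     save_pipeline(pipeline, "pipeline.pkl")
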

def load_pipeline(file_path):
    """Loads a pipeline from the given file path.

    Args:
        file_path (str) : location to load the pipeline file from

    Returns:
        Pipeline object
    """
    with open(file_path, 'rb') as f:
        return cloudpickle.load(f)
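

# A minimal round-trip sketch: a pipeline saved with save_pipeline can be restored with
# load_pipeline; the "pipeline.pkl" path is carried over from the save example above.
#
#     restored_pipeline = load_pipeline("pipeline.pkl")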