Source code for chap_core.runners.mlflow_runner
from chap_core.exceptions import ModelFailedException
from chap_core.runners.runner import TrainPredictRunner
import mlflow.exceptions
import mlflow.projects
from mlflow.utils.process import ShellCommandException
import logging
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
[docs]
class MlFlowTrainPredictRunner(TrainPredictRunner):
    """Run the ``train`` and ``predict`` entry points of an MLflow project.

    The project located at ``model_path`` is executed through
    ``mlflow.projects.run``. Besides the fixed parameters of each entry point,
    an optional ``model_config`` parameter is forwarded when it is listed in
    ``extra_params``.
    """

    def __init__(self, model_path, model_configuration_filename=None, train_params=None):
        """Store the project location and work out which extra parameters to forward.

        Args:
            model_path: Location of the MLflow project; converted with ``str()``
                before being handed to ``mlflow.projects.run``.
            model_configuration_filename: Optional model configuration file. Its
                string form is forwarded as ``model_config`` when that name is
                one of the extra parameters.
            train_params: Optional iterable of parameter names for the ``train``
                entry point. Every name other than ``train_data`` and ``model``
                is recorded in ``extra_params``.
        """
        self.model_path = model_path
        self.model_configuration_filename = model_configuration_filename
        # TODO: this logic should probably be improved by finding out which
        # parameters are actually used in the MLproject file.
        # For now the extra parameters are assumed to be the same for train and predict.
        if train_params is None:
            self.extra_params = []
        else:
            self.extra_params = [key for key in train_params if key not in ["train_data", "model"]]

    def _optional_parameters(self):
        """Return the optional entry-point parameters that are listed in ``extra_params``."""
        config = self.model_configuration_filename
        # NOTE(review): when ``model_config`` is an extra parameter but no configuration
        # file is set, ``None`` is forwarded as its value (unchanged from the previous
        # behaviour) -- confirm that the MLproject entry points cope with that.
        candidates = {"model_config": str(config) if config else None}
        if config is not None and "model_config" not in self.extra_params:
            # Without this warning the configuration file would be dropped silently.
            logger.warning(
                "Model configuration %s is ignored: 'model_config' is not among the extra parameters %s",
                config,
                self.extra_params,
            )
        return {key: val for key, val in candidates.items() if key in self.extra_params}

    def train(self, train_file_name, model_file_name, polygons_file_name=None):
        """Run the project's ``train`` entry point.

        Args:
            train_file_name: Training data file, passed as the ``train_data`` parameter.
            model_file_name: Model file, passed as the ``model`` parameter.
            polygons_file_name: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``mlflow.projects.run`` returns.

        Raises:
            ModelFailedException: If MLflow raises ``ShellCommandException`` or
                ``ExecutionException`` while running the project.
        """
        parameters = {"train_data": str(train_file_name), "model": str(model_file_name)}
        logger.info("Training model using MLflow, working dir is %s. Train data: %s", self.model_path, parameters)
        parameters.update(self._optional_parameters())
        try:
            return mlflow.projects.run(
                str(self.model_path),
                entry_point="train",
                parameters=parameters,
                build_image=True,
            )
        except ShellCommandException as e:
            logger.error(
                "Error running mlflow project, might be due to missing pyenv (See: https://github.com/pyenv/pyenv#installation)"
            )
            raise ModelFailedException(str(e)) from e
        except mlflow.exceptions.ExecutionException as e:
            logger.error("Executation of model failed for some reason. Check the logs for more information")
            raise ModelFailedException(str(e)) from e

    def predict(self, model_file_name, historic_data, future_data, output_file, polygons_file_name=None):
        """Run the project's ``predict`` entry point.

        Args:
            model_file_name: Model file, passed as the ``model`` parameter.
            historic_data: Passed as the ``historic_data`` parameter.
            future_data: Passed as the ``future_data`` parameter.
            output_file: Passed as the ``out_file`` parameter.
            polygons_file_name: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``mlflow.projects.run`` returns.
        """
        logger.info("Running predict with output to %s", output_file)
        params = {
            "historic_data": str(historic_data),
            "future_data": str(future_data),
            "model": str(model_file_name),
            "out_file": str(output_file),
        }
        logger.info("Params for predict: %s", params)
        params.update(self._optional_parameters())
        # NOTE(review): unlike ``train``, MLflow errors are not translated into
        # ModelFailedException here and ``build_image`` is not passed; both are left
        # unchanged so callers keep seeing the same exceptions -- confirm this is intended.
        return mlflow.projects.run(
            str(self.model_path),
            entry_point="predict",
            parameters=params,
        )