Source code for chap_core.hpo.objective

from typing import Literal, Optional

from chap_core.models.model_template import ModelTemplate
from chap_core.database.model_templates_and_config_tables import ModelConfiguration
from chap_core.assessment.prediction_evaluator import evaluate_model
from chap_core.exceptions import NoPredictionsError
from chap_core.file_io.example_data_set import DataSetType

import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)


[docs] class Objective: def __init__( self, model_template: ModelTemplate, metric: str = "MSE", prediction_length: int = 3, # 6, n_splits: int = 4, ignore_environment: bool = False, debug: bool = False, log_file: Optional[str] = None, run_directory_type: Optional[Literal["latest", "timestamp", "use_existing"]] = "timestamp", ): self.model_template = model_template self.metric = metric self.prediction_length = prediction_length self.n_splits = n_splits def __call__(self, config, dataset: Optional[DataSetType] = None) -> float: """ This method takes a concrete configuration produced by a Searcher, runs model evaluation, and returns a scalar score of the selected metric. """ logger.info("Validating model configuration") model_configs = {"user_option_values": config} # TODO: should prob be removed model_config = ModelConfiguration.model_validate(model_configs) logger.info("Validated model configuration") model = self.model_template.get_model(model_config) model = model() try: # evaluate_model should handle CV/nested CV and return mean results # stratified fold/splits results = evaluate_model( model, dataset, prediction_length=self.prediction_length, n_test_sets=self.n_splits, ) except NoPredictionsError as e: logger.error(f"No predictions were made: {e}") return # maybe return float("inf") here? return results[0][self.metric]