Source code for chap_core.models.external_chapkit_model

import logging
import pandas as pd
from chap_core.external.model_configuration import ModelTemplateConfigV2
from chap_core.models.external_model import ExternalModelBase
from chap_core.models.chapkit_rest_api_wrapper import CHAPKitRestAPIWrapper
from chap_core.spatio_temporal_data.temporal_dataclass import DataSet


from chap_core.datatypes import Samples

logger = logging.getLogger(__name__)


class ExternalChapkitModelTemplate:
    """Wrapper around external models that are based on chapkit.

    The template talks to a chapkit model service through
    ``CHAPKitRestAPIWrapper``. ``get_model`` stores a configuration in that
    service and returns an ``ExternalChapkitModel`` bound to the id of the
    stored configuration.

    This class is meant to be backwards compatible with ExternalModelTemplate.
    """

    # Seconds to wait between two consecutive health checks in wait_for_healthy.
    _HEALTH_POLL_INTERVAL_SECONDS = 2

    def __init__(self, rest_api_url: str):
        """Create a template for the chapkit service at ``rest_api_url``.

        The service is not health-checked here; use ``is_healthy`` or
        ``wait_for_healthy`` before relying on it.
        """
        self.rest_api_url = rest_api_url
        self.client = CHAPKitRestAPIWrapper(rest_api_url)

    def wait_for_healthy(self, timeout=60):
        """Poll the service until a health check succeeds.

        Args:
            timeout: Maximum number of seconds to keep polling.

        Returns:
            ``True`` as soon as ``is_healthy`` reports success.

        Raises:
            TimeoutError: If no health check succeeded within ``timeout``
                seconds.
        """
        import time

        # A monotonic clock is immune to wall-clock adjustments (NTP, manual
        # changes), which could otherwise shorten or stretch the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.is_healthy():
                return True
            # Never sleep past the deadline, so the timeout is honoured.
            remaining = max(0.0, deadline - time.monotonic())
            time.sleep(min(self._HEALTH_POLL_INTERVAL_SECONDS, remaining))
        raise TimeoutError(
            f"Model service at {self.rest_api_url} did not become healthy within {timeout} seconds. Check {self.rest_api_url}/health"
        )

    def is_healthy(self) -> bool:
        """Return whether the service's health endpoint reports ``"healthy"``.

        Any exception raised while querying or reading the response is logged
        at INFO level and treated as "not healthy"; this method never raises
        for a failed check, which is what lets ``wait_for_healthy`` keep
        polling while the service is still starting.
        """
        try:
            response = self.client.health()
            return response["status"] == "healthy"
        except Exception as e:
            logger.info(
                f"Health check for model {self.rest_api_url} failed: {e}. Check health at {self.rest_api_url}/health"
            )
            return False

    def get_model(self, model_configuration: "dict | None") -> "ExternalChapkitModel":
        """Store ``model_configuration`` in the service and return a model for it.

        Sends the model configuration for storing in the model (by sending to
        the model rest api). This returns a configuration id back that we can
        use to identify the model.

        Args:
            model_configuration: Mapping of configuration values, or ``None``
                for an empty configuration. The mapping is copied; the
                caller's object is not modified. A ``"model_template"`` key,
                if present, is dropped before sending.

        Returns:
            An ``ExternalChapkitModel`` bound to the newly created
            configuration id.

        Raises:
            AssertionError: If the created configuration id is not present in
                the list of configurations returned by the service.
        """
        import time

        # Copy so that removing "model_template" below never mutates the
        # caller's dict.
        model_configuration = {} if model_configuration is None else dict(model_configuration)

        # Resolve the template name once: every access of ``self.name`` issues
        # an info() request, and doing it before create_config means a failing
        # info() call cannot leave an orphaned configuration behind.
        template_name = self.name

        # Always make sure the config has a unique name for now. Chapkit uses
        # the name as identifier, but we don't necessarily do that on the chap
        # side, so a microsecond timestamp is appended.
        timestamp = int(time.time() * 1000000)
        if "name" in model_configuration:
            name = f"{model_configuration['name']}_{timestamp}"
        else:
            name = f"{template_name}_config_{timestamp}"

        model_configuration.pop("model_template", None)

        config_data = {"name": name, "data": model_configuration}
        logger.info(f"Creating model configuration with name {name} at {self.rest_api_url}. Data: {config_data}")

        config_response = self.client.create_config(config_data)
        configuration_id = config_response["id"]

        # Sanity check: the service must list the configuration it just created.
        all_configs = self.client.list_configs()
        assert any(cfg["id"] == configuration_id for cfg in all_configs), (
            f"Created configuration {configuration_id} not found in list of configs"
        )
        logger.info(f"Created model configuration with id {configuration_id} at {self.rest_api_url}")

        return ExternalChapkitModel(template_name, self.rest_api_url, configuration_id=configuration_id)

    @property
    def name(self):
        """Name of the model, built from the service's info endpoint.

        Uses the ``"name"`` field of the info response when present (not
        supported in the current chapkit version, might be supported in the
        future); otherwise the ``"display_name"`` lower-cased with spaces
        replaced by underscores. The ``"version"`` field is appended as
        ``_v<version>``.

        Note that every access performs an ``info()`` call on the client.
        """
        info = self.client.info()
        if "name" in info:
            name = info["name"]
        else:
            name = info["display_name"].lower().replace(" ", "_")
        version = info.get("version")
        return f"{name}_v{version}"

    def get_model_template_config(self) -> "ModelTemplateConfigV2":
        """Build a ``ModelTemplateConfigV2`` describing this model service.

        This method is meant to make things backwards compatible with the old
        system. An object of type ModelTemplateConfigV2 is needed to store
        info about a ModelTemplate in the database.

        Returns:
            The result of ``ModelTemplateConfigV2.model_validate`` applied to
            a dict assembled from the service's info endpoint and its config
            schema.
        """
        model_info = self.client.info()

        # User options are the properties of the "ModelConfiguration"
        # definition in the config schema, when the schema provides one.
        config_schema = self.client.get_config_schema()
        logger.debug("Config schema from %s: %s", self.rest_api_url, config_schema)
        user_options = {}
        if "$defs" in config_schema and "ModelConfiguration" in config_schema["$defs"]:
            user_options = config_schema["$defs"]["ModelConfiguration"].get("properties", {})

        # Metadata comes from the info endpoint, with fallbacks for the fields
        # that have one.
        meta_data_dict = {
            "display_name": model_info.get("display_name", "No Display Name"),
            "description": model_info.get("description") or model_info.get("summary", "No Description"),
            "author_note": model_info.get("author_note", ""),
            "author_assessed_status": model_info.get("author_assessed_status", "red"),
            "author": model_info.get("author", "Unknown Author"),
            "organization": model_info.get("organization"),
            "organization_logo_url": model_info.get("organization_logo_url"),
            "contact_email": model_info.get("contact_email"),
            "citation_info": model_info.get("citation_info"),
        }

        config_dict = {
            "name": self.name,
            "rest_api_url": self.rest_api_url,
            "meta_data": meta_data_dict,
            "required_covariates": model_info.get("required_covariates", []),
            "allow_free_additional_continuous_covariates": model_info.get(
                "allow_free_additional_continuous_covariates", False
            ),
            "user_options": user_options,
            # ModelTemplateInformation fields will use defaults if not provided:
            # - supported_period_type defaults to PeriodType.any
            # - target defaults to "disease_cases"
            # RunnerConfig fields not needed for REST API models:
            "entry_points": None,
            "docker_env": None,
            "python_env": None,
            "source_url": self.rest_api_url,
        }

        return ModelTemplateConfigV2.model_validate(config_dict)
class ExternalChapkitModel(ExternalModelBase):
    """Model that trains and predicts through a chapkit REST service.

    An instance is bound to one service URL and one configuration id stored
    in that service. ``train`` records the id of the resulting model artifact,
    which ``predict`` then passes back to the service.

    The helpers ``_get_frequency`` and ``_adapt_data`` are not defined in this
    class; presumably they are inherited from ``ExternalModelBase``.
    """

    def __init__(self, model_name: str, rest_api_url: str, configuration_id: str):
        """Bind the model to a service URL and a stored configuration id."""
        # Identity of the model and of its configuration in the service.
        self.model_name = model_name
        self.rest_api_url = rest_api_url
        self.configuration_id = configuration_id

        # State that starts out empty; _train_id is filled in by train().
        self._location_mapping = None
        self._adapters = None
        self._train_id = None

        self.client = CHAPKitRestAPIWrapper(rest_api_url)

    def train(self, train_data: DataSet, extra_args=None):
        """Train the remote model on ``train_data`` and remember the artifact id.

        Args:
            train_data: Training data; converted to pandas and passed through
                ``_adapt_data`` before it is sent, together with its polygons.
            extra_args: Accepted but not used by this implementation.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            RuntimeError: If the service reports status ``"failed"``.
            AssertionError: If the response carries no model artifact id.
        """
        frequency = self._get_frequency(train_data)
        adapted_frame = self._adapt_data(train_data.to_pandas(), frequency=frequency)

        response = self.client.train_and_wait(
            self.configuration_id,
            adapted_frame,
            train_data.polygons,
        )
        if response["status"] == "failed":
            raise RuntimeError(f"Training failed: {response.get('error', 'Unknown error')}")

        model_artifact_id = response["model_artifact_id"]
        assert model_artifact_id is not None, response
        self._train_id = model_artifact_id
        return self

    def predict(self, historic_data: DataSet, future_data: DataSet) -> DataSet:
        """Request predictions from the trained remote model.

        Args:
            historic_data: Historic observations; its polygons are sent as the
                geo features.
            future_data: Data for the periods to predict.

        Returns:
            A ``DataSet`` built with ``DataSet.from_pandas(..., Samples)``
            from the ``"data"`` and ``"columns"`` entries of the prediction
            artifact.

        Raises:
            AssertionError: If ``train`` has not stored an artifact id yet, or
                if the response carries no prediction artifact id.
            RuntimeError: If the service reports status ``"failed"``.
        """
        assert self._train_id is not None, "Model must be trained before prediction"

        geo_features = historic_data.polygons
        historic_frame = self._adapt_data(historic_data.to_pandas())
        future_frame = self._adapt_data(future_data.to_pandas())

        response = self.client.predict_and_wait(
            model_artifact_id=self._train_id,
            future_data=future_frame,
            historic_data=historic_frame,
            geo_features=geo_features,
        )
        if response["status"] == "failed":
            raise RuntimeError(f"Prediction failed: {response.get('error', 'Unknown error')}")

        prediction_artifact_id = response["prediction_artifact_id"]
        assert prediction_artifact_id is not None, response.get("error", "No prediction artifact")

        # The artifact stores the predictions as a table: rows under "data",
        # header under "columns".
        artifact = self.client.get_artifact(prediction_artifact_id)
        table = artifact["data"]["predictions"]
        frame = pd.DataFrame(data=table["data"], columns=table["columns"])
        return DataSet.from_pandas(frame, Samples)