import dataclasses
import datetime
import json
import logging
# Check if CHAP_DATABASE_URL is set in the environment
import os
from pathlib import Path
import time
from typing import Iterable, List, Optional
import psycopg2
import sqlalchemy
from sqlalchemy.orm import selectinload
from sqlmodel import Session, SQLModel, create_engine, select
from chap_core.datatypes import FullData, SamplesWithTruth, create_tsdataclass
from chap_core.geometry import Polygons
from chap_core.log_config import is_debug_mode
from chap_core.predictor.naive_estimator import NaiveEstimator
from chap_core.time_period import TimePeriod, Month, Week
from .. import ModelTemplateInterface
from ..external.model_configuration import ModelTemplateConfigV2
from ..models import ModelTemplate
from ..models.configured_model import ConfiguredModel
from ..models.external_chapkit_model import ExternalChapkitModelTemplate
from ..rest_api.data_models import BackTestCreate
from ..spatio_temporal_data.converters import observations_to_dataset
from ..spatio_temporal_data.temporal_dataclass import DataSet as _DataSet
from .dataset_tables import DataSet, Observation, DataSetCreateInfo, DataSetInfo
from .debug import DebugEntry
from .model_spec_tables import ModelSpecRead
from .model_templates_and_config_tables import ConfiguredModelDB, ModelConfiguration, ModelTemplateDB
from .tables import BackTest, BackTestForecast, Prediction, PredictionSamplesEntry
logger = logging.getLogger(__name__)
engine = None
database_url = os.getenv("CHAP_DATABASE_URL", default=None)
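# CHAP_DATABASE_URL is expected to be a SQLAlchemy connection URL, for example (illustrative values):
# "postgresql+psycopg2://user:password@localhost:5432/chap" or "sqlite:///chap.db"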
logger.info(f"Database url: {database_url}")
if database_url is not None:
n = 0
while n < 30:
try:
engine = create_engine(database_url, echo=is_debug_mode())
break
        except (sqlalchemy.exc.OperationalError, psycopg2.OperationalError) as e:
            logger.error(f"Failed to connect to database: {e}. Trying again")
            n += 1
            time.sleep(1)
else:
raise ValueError("Failed to connect to database")
else:
logger.warning("Database url not set. Database operations will not work")
class SessionWrapper:
"""
    A wrapper around data access operations.
    This class handles the cases where putting data into or getting it out of the database
    requires more than just adding/reading a single row, e.g. transforming data.
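
    Example (illustrative sketch; assumes a SQLModel engine, e.g. the module-level engine
    created from CHAP_DATABASE_URL):

        with SessionWrapper(engine) as session:
            dataset_ids = [ds.id for ds in session.list_all(DataSet)]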
"""
def __init__(self, local_engine=None, session=None):
self.engine = local_engine # or engine
self.session: Optional[Session] = session
def __enter__(self):
self.session = Session(self.engine)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.session.close()
return False
def list_all(self, model):
return self.session.exec(select(model)).all()
def create_if_not_exists(self, model, id_name="id"):
        logger.info(f"Creating {model} if it does not already exist")
T = type(model)
if not self.session.exec(select(T).where(getattr(T, id_name) == getattr(model, id_name))).first():
self.session.add(model)
self.session.commit()
return model
def add_model_template(self, model_template: ModelTemplateDB) -> int:
# check if model template already exists
existing_template = self.session.exec(
select(ModelTemplateDB).where(ModelTemplateDB.name == model_template.name)
).first()
if existing_template:
logger.info(f"Model template with name {model_template.name} already exists. Returning existing id")
return existing_template.id
# add db entry
logger.info(f"Adding model template: {model_template}")
self.session.add(model_template)
self.session.commit()
# return id
return model_template.id
def add_model_template_from_yaml_config(self, model_template_config: ModelTemplateConfigV2) -> int:
"""Sets the ModelSpecRead a yaml string.
Note that the yaml string is what's defined in a model template's MLProject file,
so source_url will have to be added manually."""
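        # Illustrative usage sketch (assumes `template_config` is a ModelTemplateConfigV2 that was
        # parsed elsewhere, e.g. from a model template's MLProject yaml file):
        #
        #     with SessionWrapper(engine) as session:
        #         template_id = session.add_model_template_from_yaml_config(template_config)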
# TODO: maybe just use add_model_template and make sure to structure it correctly first
# TODO: needs cleanup
# TODO: existing check should probably use name instead of source url
# parse yaml content as dict
existing_template = self.session.exec(
select(ModelTemplateDB).where(ModelTemplateDB.name == model_template_config.name)
).first()
d = model_template_config.model_dump()
info = d.pop("meta_data")
d = d | info
if existing_template:
logger.info(f"Model template with name {model_template_config.name} already exists. Updating it")
# Update the existing template with new data
for key, value in d.items():
if hasattr(existing_template, key):
setattr(existing_template, key, value)
self.session.commit()
return existing_template.id
# Create new template
db_object = ModelTemplateDB(**d)
logger.info(f"Adding model template: {db_object}")
self.session.add(db_object)
self.session.commit()
return db_object.id
def get_model_template(self, model_template_id: int) -> ModelTemplateInterface:
model_template = self.session.get(ModelTemplateDB, model_template_id)
if model_template is None:
raise ValueError(f"Model template with id {model_template_id} not found")
return model_template
def get_backtest_with_truth(self, backtest_id: int) -> BackTest:
backtest = self.session.get(BackTest, backtest_id)
if backtest is None:
raise ValueError(f"Backtest with id {backtest_id} not found")
dataset = backtest.dataset
if dataset is None:
raise ValueError(f"Dataset for backtest with id {backtest_id} not found")
entries = backtest.forecasts
        if entries is None or len(entries) == 0:
            raise ValueError(f"No forecasts found for backtest with id {backtest_id}")
        return backtest
def add_evaluation_results(
self,
evaluation_results: Iterable[_DataSet[SamplesWithTruth]],
last_train_period: TimePeriod,
info: BackTestCreate,
):
info.created = datetime.datetime.now()
# org_units = list({location for ds in evaluation_results for location in ds.locations()})
# split_points = list({er.period_range[0] for er in evaluation_results})
        model_db = self.session.exec(select(ConfiguredModelDB).where(ConfiguredModelDB.name == info.model_id)).first()
        if model_db is None:
            raise ValueError(f"Configured model with name {info.model_id} not found")
        model_db_id = model_db.id
backtest = BackTest(
**info.dict() | {"model_db_id": model_db_id, "model_template_version": model_db.model_template.version}
)
self.session.add(backtest)
        org_units = set()
        split_points = set()
# define metrics (for each period)
evaluation_results = list(
evaluation_results
) # hacky, to avoid metric funcs using up the iterable before we can loop all splitpoints
for eval_result in evaluation_results:
first_period: TimePeriod = eval_result.period_range[0]
split_points.add(first_period.id)
for location, samples_with_truth in eval_result.items():
# NOTE: samples_with_truth is class datatypes.SamplesWithTruth
org_units.add(location)
for period, sample_values, disease_cases in zip(
eval_result.period_range, samples_with_truth.samples, samples_with_truth.disease_cases
):
# add forecast series for this period
forecast = BackTestForecast(
period=period.id,
org_unit=location,
last_train_period=last_train_period.id,
last_seen_period=first_period.id,
values=sample_values.tolist(),
)
backtest.forecasts.append(forecast)
backtest.org_units = list(org_units)
backtest.split_periods = list(split_points)
self.session.commit()
# read full backtest object so that we can compute aggregate metrics using this object
# note: we don't do this anymore
# backtest = self.session.get(BackTest, backtest.id)
# aggregate_metrics = compute_all_aggregated_metrics_from_backtest(backtest)
# logger.info(f"aggregate metrics {aggregate_metrics}")
# backtest.aggregate_metrics = aggregate_metrics
# self.session.commit()
# add more
return backtest.id
    def add_predictions(self, predictions, dataset_id, model_id, name, metadata: Optional[dict] = None):
n_periods = len(list(predictions.values())[0])
samples_ = [
PredictionSamplesEntry(period=period.id, org_unit=location, values=value.tolist())
for location, data in predictions.items()
for period, value in zip(data.time_period, data.samples)
]
org_units = list(predictions.keys())
model_db_id = self.session.exec(select(ConfiguredModelDB.id).where(ConfiguredModelDB.name == model_id)).first()
prediction = Prediction(
dataset_id=dataset_id,
model_id=model_id,
name=name,
created=datetime.datetime.now(),
n_periods=n_periods,
            meta_data=metadata or {},
forecasts=samples_,
org_units=org_units,
model_db_id=model_db_id,
)
self.session.add(prediction)
self.session.commit()
return prediction.id
def add_dataset_from_csv(self, name: str, csv_path: Path, geojson_path: Optional[Path] = None):
dataset = _DataSet.from_csv(csv_path, dataclass=FullData)
        geojson_content = geojson_path.read_text() if geojson_path else None
features = None
if geojson_content is not None:
features = Polygons.from_geojson(json.loads(geojson_content), id_property="NAME_1").feature_collection()
features = features.model_dump_json()
return self.add_dataset(DataSetCreateInfo(name=name), dataset, features)
def add_dataset(self, dataset_info: DataSetCreateInfo, orig_dataset: _DataSet, polygons):
"""
Add a dataset to the database. The dataset is provided as a spatio-temporal dataclass.
The polygons should be provided as a geojson feature collection.
The dataset_info should contain information about the dataset, such as its name and data sources.
The function sets some derived fields in the dataset_info, such as the first and last time period and the covariates.
The function returns the id of the newly created dataset.
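
        Example (illustrative sketch; assumes `dataset` is a spatio-temporal DataSet and
        `feature_collection_json` is a geojson feature collection serialized to a JSON string):

            with SessionWrapper(engine) as session:
                dataset_id = session.add_dataset(
                    DataSetCreateInfo(name="my_dataset"),
                    dataset,
                    feature_collection_json,
                )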
"""
logger.info(
f"Adding dataset {dataset_info.name} with {len(list(orig_dataset.locations()))} locations and {len(orig_dataset.period_range)} time periods"
)
field_names = [
field.name
for field in dataclasses.fields(next(iter(orig_dataset.values())))
if field.name not in ["time_period", "location"]
]
logger.info(f"Field names in dataset: {field_names}")
if isinstance(orig_dataset.period_range[0], Month):
period_type = "month"
else:
assert isinstance(orig_dataset.period_range[0], Week), orig_dataset.period_range[0]
period_type = "week"
full_info = DataSetInfo(
first_period=orig_dataset.period_range[0].id,
last_period=orig_dataset.period_range[-1].id,
covariates=field_names,
created=datetime.datetime.now(),
org_units=list(orig_dataset.locations()),
period_type=period_type,
**dataset_info.model_dump(),
)
dataset = DataSet(geojson=polygons, **full_info.model_dump())
for location, data in orig_dataset.items():
field_names = [
field.name for field in dataclasses.fields(data) if field.name not in ["time_period", "location"]
]
for row in data:
for field in field_names:
observation = Observation(
period=row.time_period.id,
org_unit=location,
value=float(getattr(row, field)),
feature_name=field,
)
dataset.observations.append(observation)
self.session.add(dataset)
self.session.commit()
assert self.session.exec(select(Observation).where(Observation.dataset_id == dataset.id)).first() is not None
return dataset.id
def get_dataset(self, dataset_id, dataclass: type | None = None) -> _DataSet:
dataset = self.session.get(DataSet, dataset_id)
if dataclass is None:
logger.info(f"Getting dataset with covariates: {dataset.covariates} and name: {dataset.name}")
field_names = dataset.covariates
dataclass = create_tsdataclass(field_names)
observations = dataset.observations
new_dataset = observations_to_dataset(dataclass, observations)
return new_dataset
def get_dataset_by_name(self, dataset_name: str) -> Optional[DataSet]:
dataset = self.session.exec(select(DataSet).where(DataSet.name == dataset_name)).first()
return dataset
def add_debug(self):
"""Function for debuging"""
debug_entry = DebugEntry(timestamp=time.time())
self.session.add(debug_entry)
self.session.commit()
return debug_entry.id
def _run_alembic_migrations(engine):
"""
Run Alembic migrations programmatically.
This is called after the custom migration system to apply any Alembic-managed schema changes.
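    Roughly equivalent to running `alembic upgrade head` from the project root, except that the
    already configured engine connection is handed to Alembic programmatically.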
"""
from alembic import command
from alembic.config import Config
logger.info("Running Alembic migrations")
try:
# Get the path to alembic.ini relative to the project root
project_root = Path(__file__).parent.parent.parent
alembic_ini_path = project_root / "alembic.ini"
if not alembic_ini_path.exists():
logger.warning(f"Alembic config not found at {alembic_ini_path}. Skipping Alembic migrations.")
return
# Create Alembic config
alembic_cfg = Config(str(alembic_ini_path))
# Pass the engine connection to Alembic for programmatic usage
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
command.upgrade(alembic_cfg, "head")
logger.info("Completed Alembic migrations successfully")
except Exception as e:
logger.error(f"Error during Alembic migrations: {e}", exc_info=True)
# Don't raise - allow system to continue if Alembic fails
# This ensures backward compatibility
def create_db_and_tables():
# TODO: Read config for options on how to create the database migrate/update/seed/seed_and_update
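    # Typical usage (illustrative): call create_db_and_tables() once at application startup,
    # after the module-level engine has been created from CHAP_DATABASE_URL.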
if engine is not None:
logger.info("Engine set. Creating tables")
n = 0
while n < 30:
try:
# Step 1: Run custom migrations for backward compatibility (v1.0.17, etc.)
_run_generic_migration(engine)
# Step 2: Create any new tables that don't exist yet
SQLModel.metadata.create_all(engine)
# Step 3: Run Alembic migrations for future schema changes
_run_alembic_migrations(engine)
                logger.info("Tables created and migrations completed")
break
except sqlalchemy.exc.OperationalError as e:
logger.error(f"Failed to create tables: {e}. Trying again")
n += 1
time.sleep(1)
                if n >= 30:
raise e
with SessionWrapper(engine) as session:
from .model_template_seed import seed_configured_models_from_config_dir
seed_configured_models_from_config_dir(session.session)
# seed_example_datasets(session)
else:
logger.warning("Engine not set. Tables not created")
def _run_v1_0_17_migrations(conn, engine):
"""
Specific migrations needed when upgrading from v1.0.17 to current version.
This handles data type conversions and corrections that the generic migration cannot handle.
"""
logger.info("Running v1.0.17 specific migrations")
inspector = sqlalchemy.inspect(engine)
existing_tables = inspector.get_table_names()
try:
# Fix: Check if modeltemplatedb table exists and has corrupted data
if "modeltemplatedb" in existing_tables:
logger.info("Checking for corrupted PeriodType enum values in modeltemplatedb")
# Check if there are any rows with invalid enum values (like '1' instead of 'week', 'month', etc.)
check_sql = """
SELECT COUNT(*) as count
FROM modeltemplatedb
WHERE supported_period_type NOT IN ('week', 'month', 'any', 'year')
"""
result = conn.execute(sqlalchemy.text(check_sql)).fetchone()
if result and result[0] > 0:
logger.warning(f"Found {result[0]} rows with corrupted PeriodType enum values, fixing...")
# Map common corrupted values to correct enum values
# If the value is '1', we'll default to 'any' as it's the most permissive
fix_sql = """
UPDATE modeltemplatedb
SET supported_period_type = 'any'::periodtype
WHERE supported_period_type NOT IN ('week', 'month', 'any', 'year')
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
logger.info("Fixed corrupted PeriodType enum values")
# Fix: Ensure JSON columns that should be arrays are arrays, not objects
if "dataset" in existing_tables:
columns = {col["name"] for col in inspector.get_columns("dataset")}
# Fix data_sources if it exists and contains objects instead of arrays
if "data_sources" in columns:
logger.info("Fixing data_sources column in dataset table")
fix_sql = """
UPDATE dataset
SET data_sources = '[]'::json
WHERE data_sources IS NULL OR data_sources::text = '{}'
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
# Fix org_units if it exists and contains objects instead of arrays
if "org_units" in columns:
logger.info("Fixing org_units column in dataset table")
fix_sql = """
UPDATE dataset
SET org_units = '[]'::json
WHERE org_units IS NULL OR org_units::text = '{}'
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
# Fix covariates if it contains objects instead of arrays
if "covariates" in columns:
logger.info("Fixing covariates column in dataset table")
fix_sql = """
UPDATE dataset
SET covariates = '[]'::json
WHERE covariates IS NULL OR covariates::text = '{}'
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
# Fix backtest table JSON columns
if "backtest" in existing_tables:
columns = {col["name"] for col in inspector.get_columns("backtest")}
if "org_units" in columns:
logger.info("Fixing org_units column in backtest table")
fix_sql = """
UPDATE backtest
SET org_units = '[]'::json
WHERE org_units IS NULL OR org_units::text = '{}'
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
if "split_periods" in columns:
logger.info("Fixing split_periods column in backtest table")
fix_sql = """
UPDATE backtest
SET split_periods = '[]'::json
WHERE split_periods IS NULL OR split_periods::text = '{}'
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
# Fix configuredmodeldb table JSON columns
if "configuredmodeldb" in existing_tables:
columns = {col["name"] for col in inspector.get_columns("configuredmodeldb")}
if "user_option_values" in columns:
logger.info("Fixing user_option_values column in configuredmodeldb table")
# This one should actually be an object {}, not an array
fix_sql = """
UPDATE configuredmodeldb
SET user_option_values = '{}'::json
WHERE user_option_values IS NULL
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
if "additional_continuous_covariates" in columns:
logger.info("Fixing additional_continuous_covariates column in configuredmodeldb table")
fix_sql = """
UPDATE configuredmodeldb
SET additional_continuous_covariates = '[]'::json
WHERE additional_continuous_covariates IS NULL OR additional_continuous_covariates::text = '{}'
"""
conn.execute(sqlalchemy.text(fix_sql))
conn.commit()
logger.info("Completed v1.0.17 specific migrations successfully")
except Exception as e:
logger.error(f"Error during v1.0.17 migrations: {e}")
conn.rollback()
raise
def _run_generic_migration(engine):
"""
Generic migration function that adds missing columns to existing tables
and sets default values for new columns in existing records.
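
    For example, if a table in the database is missing a column that is defined on the
    corresponding SQLModel table, this emits a statement roughly like the following
    (table and column names are illustrative):

        ALTER TABLE backtest ADD COLUMN name VARCHAR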
"""
logger.info("Running generic migration for missing columns")
with engine.connect() as conn:
# Run v1.0.17 specific migrations first
# _run_v1_0_17_migrations(conn, engine)
# Get current database schema
inspector = sqlalchemy.inspect(engine)
existing_tables = inspector.get_table_names()
for table_name, table in SQLModel.metadata.tables.items():
if table_name not in existing_tables:
logger.info(f"Table {table_name} doesn't exist yet, will be created")
continue
# Get existing columns in the database
existing_columns = {col["name"] for col in inspector.get_columns(table_name)}
# Check for missing columns
for column in table.columns:
if column.name not in existing_columns:
logger.info(f"Adding missing column {column.name} to table {table_name}")
# Add the column
column_type = column.type.compile(dialect=engine.dialect)
add_column_sql = f"ALTER TABLE {table_name} ADD COLUMN {column.name} {column_type}"
try:
conn.execute(sqlalchemy.text(add_column_sql))
# Set default value based on column type and properties
default_value = _get_column_default_value(column)
if default_value is not None:
update_sql = (
f"UPDATE {table_name} SET {column.name} = :default_val WHERE {column.name} IS NULL"
)
conn.execute(sqlalchemy.text(update_sql), {"default_val": default_value})
logger.info(f"Set default value for {column.name} in existing records")
conn.commit()
logger.info(f"Successfully added column {column.name} to {table_name}")
except Exception as e:
logger.error(f"Failed to add column {column.name} to {table_name}: {e}")
conn.rollback()
def _get_column_default_value(column):
"""
Determine appropriate default value for a column based on its type and properties.
"""
# Check if column has a default value defined
if column.default is not None:
if hasattr(column.default, "arg") and column.default.arg is not None:
# Check if it's a factory function (like list or dict)
if callable(column.default.arg):
try:
result = column.default.arg()
if isinstance(result, list):
return "[]"
elif isinstance(result, dict):
return "{}"
except Exception:
pass
return column.default.arg
# Check column type and provide appropriate defaults
column_type = str(column.type).lower()
if "json" in column_type or "pydanticlisttype" in column_type:
# For JSON columns, only default to [] if there's a default_factory set
# If there's no explicit default, use NULL (safer for Optional[dict] fields)
# Most list JSON columns have default_factory=list which is handled above
return None
elif "varchar" in column_type or "text" in column_type:
return "" # Empty string for text columns
elif "integer" in column_type or "numeric" in column_type:
return 0 # Zero for numeric columns
elif "boolean" in column_type:
return False # False for boolean columns
elif "timestamp" in column_type or "datetime" in column_type:
return None # Let NULL remain for timestamps
return None # Default to NULL for unknown types