# Source code for chap_core.assessment.flat_representations

from __future__ import annotations
from enum import Enum
from typing import Mapping, List

import numpy as np
import pandas as pd
from pandera import Check
import pandera.pandas as pa
from pandera.pandas import DataFrameModel

from chap_core.database.tables import BackTestForecast
from chap_core.database.dataset_tables import ObservationBase
from chap_core.time_period import TimePeriod


class FlatData(DataFrameModel):
    """
    Base schema for flat (long-format) rows keyed by location and time period.

    Every subclass inherits the two key columns below.
    """

    # Identifier of the location / org unit, stored as a string.
    location: pa.typing.Series[str]
    # String form of the time period the row refers to.
    time_period: pa.typing.Series[str]
class FlatDataWithHorizon(FlatData):
    """Flat schema extended with a forecast horizon column."""

    # Integer horizon distance for the row (see ``horizon_diff`` in this module).
    horizon_distance: pa.typing.Series[int]
class FlatObserved(FlatData):
    """Schema for observed disease case counts."""

    # Declared as float (and nullable) so missing observations can be stored as NaN.
    disease_cases: pa.typing.Series[float] = pa.Field(nullable=True)
class FlatForecasts(FlatDataWithHorizon):
    """
    Schema for forecasted disease cases.

    The predicted value lives in the ``forecast`` column. The ``sample`` column
    lets a single location/time_period/horizon_distance combination carry
    several rows, one per forecast sample.
    """

    # Index of the sample within its forecast.
    sample: pa.typing.Series[int]
    # The forecast value itself.
    forecast: pa.typing.Series[float]
class FlatMetric(FlatDataWithHorizon):
    """Schema for a metric value per location/time_period/horizon_distance row."""

    # Metric value; nullable so undefined metrics can be represented as missing.
    metric: pa.typing.Series[float] = pa.Field(nullable=True)
def horizon_diff(period: str, period2: str) -> int:
    """
    Return how many time units ``period`` lies after ``period2``.

    Both strings are parsed with ``TimePeriod.parse``. The difference is
    floor-divided by the time delta of ``period``.
    """
    later = TimePeriod.parse(period)
    earlier = TimePeriod.parse(period2)
    elapsed = later - earlier
    return elapsed // later.time_delta
def _convert_backtest_to_flat_forecasts(backtest_forecasts: List[BackTestForecast]) -> pd.DataFrame:
    """
    Convert BackTestForecast objects to a flat DataFrame following the FlatForecasts layout.

    Builds one small DataFrame per forecast (one row per sample value) and
    concatenates them.

    Args:
        backtest_forecasts: List of BackTestForecast objects containing forecasts.

    Returns:
        pd.DataFrame with columns: location, time_period, horizon_distance, sample, forecast.

    Raises:
        ValueError: If no forecast rows are produced, either because the input is
            empty or because every forecast has no sample values.
    """
    dfs = []
    for forecast in backtest_forecasts:
        # horizon_distance: how many time periods ahead of the last seen period this forecast is.
        horizon_distance = horizon_diff(str(forecast.period), str(forecast.last_seen_period))
        dfs.append(_create_df(forecast, horizon_distance))

    # pd.concat([]) raises its own opaque error, so only concatenate when there is something to join.
    df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    if df.empty:
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        raise ValueError("No forecast data found in backtest forecasts. Something wrong in model?")
    return df


def _create_df(forecast: BackTestForecast, horizon_distance: int) -> pd.DataFrame:
    """Build the flat rows for a single forecast: one row per sample value, in sample order."""
    return pd.DataFrame(
        {
            # Scalars are broadcast by pandas to the length of the array-valued columns.
            "location": str(forecast.org_unit),
            "time_period": str(forecast.period),
            "horizon_distance": horizon_distance,
            "sample": np.arange(len(forecast.values)),
            "forecast": forecast.values,
        }
    )
def convert_backtest_to_flat_forecasts(
    backtest_forecasts: List[BackTestForecast], *, validate: bool = True
) -> pd.DataFrame:
    """
    Convert BackTestForecast objects to a flat DataFrame following the FlatForecasts layout.

    Allocates one numpy array per output column up front and fills it forecast
    by forecast. Each forecast contributes one row per sample value.

    Args:
        backtest_forecasts: BackTestForecast objects to flatten.
        validate: Currently has no effect. Schema validation against
            FlatForecasts is disabled; the flag is kept for backward compatibility.

    Returns:
        pd.DataFrame with columns location and time_period (object dtype),
        horizon_distance and sample (int64), and forecast (float64). An empty
        input yields an empty DataFrame with these columns.
    """
    total = sum(len(fc.values) for fc in backtest_forecasts)

    loc_col = np.empty(total, dtype=object)
    per_col = np.empty(total, dtype=object)
    hdist_col = np.empty(total, dtype=np.int64)
    sample_col = np.empty(total, dtype=np.int64)
    forecast_col = np.empty(total, dtype=np.float64)

    offset = 0
    for fc in backtest_forecasts:
        per = str(fc.period)
        vals = np.asarray(fc.values)
        n = vals.shape[0]
        rows = slice(offset, offset + n)

        # Per-forecast scalars are broadcast over this forecast's rows.
        loc_col[rows] = str(fc.org_unit)
        per_col[rows] = per
        hdist_col[rows] = horizon_diff(per, str(fc.last_seen_period))
        sample_col[rows] = np.arange(n, dtype=np.int64)
        forecast_col[rows] = vals.astype(np.float64, copy=False)
        offset += n

    df = pd.DataFrame(
        {
            "location": loc_col,
            "time_period": per_col,
            "horizon_distance": hdist_col,
            "sample": sample_col,
            "forecast": forecast_col,
        }
    )
    # NOTE(review): `validate` is intentionally ignored; FlatForecasts.validate(df) was disabled
    # upstream. Confirm before re-enabling, since the default is True.
    return df
def convert_backtest_observations_to_flat_observations(
    observations: List[ObservationBase],
) -> pd.DataFrame:
    """
    Convert ObservationBase objects to a flat DataFrame following the FlatObserved layout.

    Only observations whose ``feature_name`` is ``"disease_cases"`` and whose
    ``value`` is not None are kept; everything else is skipped.

    Args:
        observations: ObservationBase objects to flatten.

    Returns:
        pd.DataFrame with columns location, time_period, disease_cases. The
        columns are present even when no observation qualifies, in which case
        the frame is empty and schema validation is skipped.
    """
    columns = ["location", "time_period", "disease_cases"]
    rows = [
        (str(obs.org_unit), str(obs.period), float(obs.value))
        for obs in observations
        if obs.feature_name == "disease_cases" and obs.value is not None
    ]
    # Passing `columns` keeps the schema columns even when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)
    if not df.empty:
        FlatObserved.validate(df)
    return df
def group_flat_forecast_by_horizon(flat_forecast_df: pd.DataFrame, aggregate_samples: bool = True) -> pd.DataFrame:
    """
    Optionally collapse forecast samples into their mean.

    Args:
        flat_forecast_df: Flat forecast DataFrame with location, time_period,
            horizon_distance and forecast columns (FlatForecasts layout).
        aggregate_samples: When true, average ``forecast`` over all samples
            sharing a location/time_period/horizon_distance.

    Returns:
        The aggregated DataFrame (grouping keys as columns plus the mean
        ``forecast``), or the input DataFrame object itself when
        ``aggregate_samples`` is false.
    """
    if not aggregate_samples:
        return flat_forecast_df
    keys = ["location", "time_period", "horizon_distance"]
    return flat_forecast_df.groupby(keys, as_index=False)["forecast"].mean()
class DataDimension(str, Enum):
    """
    Dimensions along which a metrics dataset can be indexed.

    Members also subclass ``str``, so each compares equal to its column name.
    """

    location = "location"
    time_period = "time_period"
    horizon_distance = "horizon_distance"
# Registry mapping each DataDimension to (expected Python type, optional pandera Check).
# Only horizon_distance carries a check: it must be non-negative.
DIM_REGISTRY: Mapping[DataDimension, tuple[type, Check | None]] = {
    DataDimension.location: (str, None),
    DataDimension.time_period: (str, None),
    DataDimension.horizon_distance: (int, Check.ge(0)),
}