Source code for chap_core.assessment.metrics.base

"""
Base classes for all metrics.
"""

from dataclasses import dataclass
import pandas as pd
import pandera.pandas as pa
from chap_core.assessment.flat_representations import (
    DIM_REGISTRY,
    DataDimension,
    FlatForecasts,
    FlatObserved,
)


@dataclass(frozen=True)
class MetricSpec:
    """
    Immutable description of what a metric produces.

    Attributes
    ----------
    output_dimensions
        Dimensions the metric's output is broken down by. An empty tuple means
        the metric yields a single number for the whole dataset.
    metric_name
        Human-readable name of the metric.
    metric_id
        Identifier of the metric.
    description
        Free-text explanation of the metric.
    """

    # Dimensions the output table is indexed by (one column per dimension).
    output_dimensions: tuple[DataDimension, ...] = ()
    # Display name; MetricBase.get_name returns this value.
    metric_name: str = "metric"
    # Identifier for the metric.
    metric_id: str = "metric"
    # Explanation shown when no better description is supplied.
    description: str = "No description provided"
class MetricBase:
    """
    Base class for metrics.

    Subclass this and implement the ``compute`` method to create a new metric.
    Define the ``spec`` attribute (a ``MetricSpec``) to specify which dimensions
    the metric's output is broken down by.
    """

    spec: MetricSpec = MetricSpec()

    def get_metric(self, observations: FlatObserved, forecasts: FlatForecasts) -> pd.DataFrame:
        """
        Compute the metric and validate the shape of its output.

        Parameters
        ----------
        observations
            Observed values. Rows whose ``disease_cases`` is null are dropped
            before ``compute`` is called.
        forecasts
            Forecast values, passed unchanged to ``compute``.

        Returns
        -------
        pd.DataFrame
            The result of ``compute`` after validation (and dtype coercion)
            against the schema built by ``_make_schema``: one column per
            output dimension plus a ``"metric"`` column.

        Raises
        ------
        ValueError
            If ``compute`` returns missing or unexpected columns.
        Any error raised by the pandera schema validation (``lazy=False``)
        propagates unchanged.
        """
        # Drop observations without a value so they do not enter the metric.
        observations = observations[observations.disease_cases.notna()]
        out = self.compute(observations, forecasts)

        # Column names are the dimension *values*, exactly as in _make_schema,
        # so this check agrees with the schema used for validation below.
        expected = [d.value for d in self.spec.output_dimensions] + ["metric"]
        missing = [c for c in expected if c not in out.columns]
        extra = [c for c in out.columns if c not in expected]
        if missing or extra:
            raise ValueError(
                f"{self.__class__.__name__} produced wrong columns.\n"
                f"Expected: {expected}\nMissing: {missing}\nExtra: {extra}"
            )
        return self._make_schema().validate(out, lazy=False)

    def compute(self, observations: pd.DataFrame, forecasts: pd.DataFrame) -> pd.DataFrame:
        """
        Compute the raw metric table; must be overridden by subclasses.

        The returned frame must contain exactly one column per dimension in
        ``self.spec.output_dimensions`` (named by the dimension's value) plus
        a ``"metric"`` column, otherwise ``get_metric`` rejects it.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement compute()")

    def _make_schema(self) -> pa.DataFrameSchema:
        """Build the pandera schema describing this metric's output table."""
        cols: dict[str, pa.Column] = {}
        for d in self.spec.output_dimensions:
            # DIM_REGISTRY supplies the dtype and optional checks per dimension.
            dtype, chk = DIM_REGISTRY[d]
            cols[d.value] = pa.Column(dtype, chk) if chk else pa.Column(dtype)
        # The metric value itself is allowed to be null.
        cols["metric"] = pa.Column(float, nullable=True)
        # strict=True rejects extra columns; coerce=True casts to declared dtypes.
        return pa.DataFrameSchema(cols, strict=True, coerce=True)

    def get_name(self) -> str:
        """Return the metric's display name from its spec."""
        return self.spec.metric_name

    def gives_highest_resolution(self) -> bool:
        """
        Returns True if the metric gives one number per
        location/time_period/horizon_distance combination.

        Determined solely by the metric having three output dimensions.
        """
        return len(self.spec.output_dimensions) == 3

    def is_full_aggregate(self) -> bool:
        """
        Returns True if the metric gives only one number for the whole dataset,
        i.e. it has no output dimensions.
        """
        return len(self.spec.output_dimensions) == 0