# Source code for chap_core.assessment.evaluator

from abc import ABC, abstractmethod

from chap_core.assessment.representations import (
    MultiLocationDiseaseTimeSeries,
    MultiLocationForecast,
    MultiLocationErrorTimeSeries,
    ErrorTimeSeries,
    Error,
)


class Evaluator(ABC):
    """Abstract interface for scoring forecasts against observed disease data.

    Subclasses must implement :meth:`evaluate`. :meth:`get_name` has a default
    implementation that subclasses may override.
    """

    @abstractmethod
    def evaluate(
        self, all_truths: MultiLocationDiseaseTimeSeries, all_forecasts: MultiLocationForecast
    ) -> MultiLocationErrorTimeSeries:
        """Compare forecasts with the truth and return the resulting errors.

        Args:
            all_truths: Observed disease time series for multiple locations.
            all_forecasts: Forecasts for multiple locations.

        Returns:
            The errors produced by the concrete evaluator.
        """
        pass

    def get_name(self) -> str:
        """Return the evaluator's name; by default the name of its concrete class."""
        return self.__class__.__name__

class ComponentBasedEvaluator(Evaluator):
    """Evaluator assembled from an error function and two optional aggregators.

    For every location, ``errorFunc`` is applied to each aligned pair of
    observation and prediction. The per-time-step errors can then be collapsed
    over time (``timeAggregationFunc``) and over locations
    (``regionAggregationFunc``); passing ``None`` for either disables that
    aggregation step.
    """

    def __init__(self, name, errorFunc, timeAggregationFunc, regionAggregationFunc):
        """Store the components used by :meth:`evaluate`.

        Args:
            name: Value returned by :meth:`get_name`.
            errorFunc: Callable invoked as
                ``errorFunc(truth.disease_cases, prediction.disease_case_samples)``
                for every time step.
            timeAggregationFunc: Callable taking the list of per-time-step
                errors of one location and returning one value, or ``None`` to
                keep one error per time step.
            regionAggregationFunc: Callable taking a list of error values (one
                per location) and returning one value, or ``None`` to keep the
                per-location results.
        """
        self._name = name
        self._errorFunc = errorFunc
        self._timeAggregationFunc = timeAggregationFunc
        self._regionAggregationFunc = regionAggregationFunc

    def get_name(self):
        """Return the name given at construction time."""
        return self._name

    def evaluate(
        self, all_truths: MultiLocationDiseaseTimeSeries, all_forecasts: MultiLocationForecast
    ) -> MultiLocationErrorTimeSeries:
        """Compute errors per location, then apply the configured aggregations.

        Args:
            all_truths: Observed series; its ``locations()`` drive the iteration.
            all_forecasts: Forecasts, looked up per location through
                ``all_forecasts.timeseries``.

        Returns:
            Without region aggregation: one ``ErrorTimeSeries`` per location,
            holding either one ``Error`` per time step or, when time
            aggregation is configured, a single ``Error`` labelled
            ``"Full_period"``. With region aggregation: a result with the
            single key ``"Full_region"``.

        Raises:
            AssertionError: If the truth and forecast series of a location
                differ in length or their time periods do not line up.
        """
        evaluation_result = MultiLocationErrorTimeSeries(timeseries_dict={})

        for location in all_truths.locations():
            current_error_series = ErrorTimeSeries(observations=[])
            forecast_series = all_forecasts.timeseries[location]
            truth_observations = all_truths[location].observations
            assert len(truth_observations) == len(forecast_series.predictions), (
                f"Truth and forecast series differ in length for location {location!r}: "
                f"{len(truth_observations)} != {len(forecast_series.predictions)}"
            )

            errors = []
            for truth, prediction in zip(truth_observations, forecast_series.predictions):
                assert truth.time_period == prediction.time_period, (
                    f"Time period mismatch for location {location!r}: "
                    f"{truth.time_period!r} != {prediction.time_period!r}"
                )
                errors.append(self._errorFunc(truth.disease_cases, prediction.disease_case_samples))
                if self._timeAggregationFunc is None:
                    # No time aggregation: keep one error per time step.
                    current_error_series.observations.append(
                        Error(time_period=truth.time_period, value=errors[-1])
                    )

            if self._timeAggregationFunc is not None:
                # Collapse all time steps of this location into one value.
                current_error_series.observations.append(
                    Error(time_period="Full_period", value=self._timeAggregationFunc(errors))
                )
            evaluation_result[location] = current_error_series

        if self._regionAggregationFunc is None:
            return evaluation_result

        final_evaluation_result = MultiLocationErrorTimeSeries(
            timeseries_dict={"Full_region": ErrorTimeSeries(observations=[])}
        )
        # Presumably each item maps location -> Error for one time point;
        # verify against MultiLocationErrorTimeSeries.locationvalues_per_timepoint.
        for locationvalues in evaluation_result.locationvalues_per_timepoint():
            aggregated_error = self._regionAggregationFunc([error.value for error in locationvalues.values()])
            # NOTE(review): the label is always "Full_period", even when no time
            # aggregation was applied and several time points are emitted.
            # Kept as is because callers may rely on it -- confirm intent.
            final_evaluation_result["Full_region"].observations.append(
                Error(time_period="Full_period", value=aggregated_error)
            )
        return final_evaluation_result