# Source code for chap_core.assessment.representations

from dataclasses import dataclass, field
from typing import List, Dict, Set


# Disease cases
@dataclass
class DiseaseObservation:
    """A single observed disease-case count for one time period."""

    # Identifier of the time period this count belongs to.
    time_period: str
    # Number of observed disease cases in that time period.
    disease_cases: int
@dataclass
class DiseaseTimeSeries:
    """A list of disease-case observations forming one time series."""

    # The observations, in the order they were supplied.
    observations: List[DiseaseObservation]
@dataclass
class MultiLocationDiseaseTimeSeries:
    """Disease-case time series for several locations, keyed by location.

    Supports ``obj[location]`` for reading and assigning a single location's
    series.
    """

    # Mapping from location to that location's disease time series.
    timeseries_dict: Dict[str, DiseaseTimeSeries] = field(default_factory=dict)

    def __setitem__(self, location, timeseries):
        self.timeseries_dict[location] = timeseries

    def __getitem__(self, location):
        return self.timeseries_dict[location]

    def locations(self):
        """Return an iterator over the location keys."""
        return iter(self.timeseries_dict.keys())

    def timeseries(self):
        """Return an iterator over the per-location time series."""
        return iter(self.timeseries_dict.values())

    def filter_by_time_periods(self, time_periods: List[str]) -> "MultiLocationDiseaseTimeSeries":
        """Return a new collection keeping only observations in ``time_periods``.

        Every location is kept, possibly with an empty series, and the order of
        observations within each series is preserved. This object is not
        modified, but the retained observation objects are shared, not copied.

        Args:
            time_periods: The time periods to keep. Any iterable of strings is
                accepted; it is consumed exactly once.

        Returns:
            A new ``MultiLocationDiseaseTimeSeries`` with the filtered series.
        """
        # Build the lookup set once so each membership test is O(1) instead of a
        # list scan per observation. This also makes one-shot iterables
        # (e.g. generators) behave correctly.
        wanted = set(time_periods)
        filtered = MultiLocationDiseaseTimeSeries()
        for location, series in self.timeseries_dict.items():
            filtered[location] = DiseaseTimeSeries(
                [obs for obs in series.observations if obs.time_period in wanted]
            )
        return filtered
@dataclass
class Error:
    """An error value associated with one time period."""

    # Time period this error value refers to.
    time_period: str
    # The error value for that time period.
    value: float
@dataclass
class ErrorTimeSeries:
    """A list of per-time-period error values forming one time series."""

    # The error entries, in the order they were supplied.
    observations: List[Error]
@dataclass
class MultiLocationErrorTimeSeries:
    """Error time series for several locations, keyed by location.

    Supports ``obj[location]`` for reading and assigning a single location's
    series. Several helpers require all locations to share the same time
    periods or series length and check this with assertions.
    """

    # Mapping from location to that location's error time series. Defaults to
    # an empty dict so an empty collection can be created and filled via
    # ``obj[location] = series``.
    timeseries_dict: Dict[str, ErrorTimeSeries] = field(default_factory=dict)

    def __getitem__(self, location):
        return self.timeseries_dict[location]

    def __setitem__(self, location, timeseries):
        self.timeseries_dict[location] = timeseries

    def locations(self):
        """Return an iterator over the location keys."""
        return iter(self.timeseries_dict.keys())

    def timeseries(self):
        """Return an iterator over the per-location error series."""
        return iter(self.timeseries_dict.values())

    def num_locations(self):
        """Return the number of locations."""
        return len(self.timeseries_dict)

    def num_timeperiods(self):
        """Return the number of time periods shared by all locations.

        Returns 0 when there are no locations.
        """
        timeperiods = self.get_all_timeperiods()
        # get_all_timeperiods returns None when there are no locations;
        # len(None) would raise TypeError.
        return 0 if timeperiods is None else len(timeperiods)

    def get_the_only_location(self):
        """Return the single location key.

        Raises:
            AssertionError: if there is not exactly one location.
        """
        count = len(self.timeseries_dict)
        assert count == 1, f"Expected exactly one location, got {count}"
        return next(iter(self.timeseries_dict))

    def get_the_only_timeseries(self):
        """Return the single error series.

        Raises:
            AssertionError: if there is not exactly one location.
        """
        count = len(self.timeseries_dict)
        assert count == 1, f"Expected exactly one location, got {count}"
        return next(iter(self.timeseries_dict.values()))

    def get_all_timeperiods(self):
        """Return the list of time periods common to every location.

        Returns:
            The ``time_period`` values of a series in observation order, or
            ``None`` when there are no locations.

        Raises:
            AssertionError: if two locations have different time-period
                sequences (compared element-wise, including order).
        """
        timeperiods = None
        for location, ts in self.timeseries_dict.items():
            current_timeperiods = [o.time_period for o in ts.observations]
            if timeperiods is None:
                timeperiods = current_timeperiods
            else:
                assert timeperiods == current_timeperiods, (
                    f"Time periods for location {location!r} differ from "
                    f"those of the preceding locations"
                )
        return timeperiods

    def timeseries_length(self):
        """Return the number of observations, which must be equal for every location.

        Raises:
            AssertionError: if the series lengths differ or there are no
                locations.
        """
        lengths = [len(ts.observations) for ts in self.timeseries_dict.values()]
        assert len(set(lengths)) == 1, (
            f"Expected one common series length, got {sorted(set(lengths))}"
        )
        return lengths[0]

    def locationvalues_per_timepoint(self) -> List[Dict[str, Error]]:
        """Regroup the series by position.

        Returns:
            A list with one dict per index ``i``, mapping each location to its
            ``i``-th error entry. Entries are paired by position, not by their
            ``time_period`` value. Raises AssertionError (via
            ``timeseries_length``) if the series lengths differ.
        """
        return [
            {location: ts.observations[i] for location, ts in self.timeseries_dict.items()}
            for i in range(self.timeseries_length())
        ]
# Forecasts
@dataclass
class Samples:
    """Sampled disease-case values for a single time period."""

    # Time period the samples refer to.
    time_period: str
    # Sample values of the disease-case count for that time period.
    disease_case_samples: List[float]
@dataclass
class Forecast:
    """A forecast made of per-time-period sample collections."""

    # One Samples entry per predicted time period, in the order supplied.
    predictions: List[Samples]
@dataclass
class MultiLocationForecast:
    """Forecasts for several locations, one ``Forecast`` per key."""

    # Mapping from location key to that location's forecast.
    timeseries: Dict[str, Forecast]

    def time_periods(self) -> Set[str]:
        """Return every time period predicted for any location, as a set."""
        return {
            samples.time_period
            for forecast in self.timeseries.values()
            for samples in forecast.predictions
        }