Source code for chap_core.assessment.representations
from dataclasses import dataclass, field
from typing import List, Dict, Set
# Disease cases
[docs]
@dataclass
class DiseaseObservation:
    """One disease-case count attached to a single time period."""

    # Label of the time period the count belongs to (format is not
    # constrained by this class).
    time_period: str
    # Case count recorded for that period.
    disease_cases: int
[docs]
@dataclass
class DiseaseTimeSeries:
    """A list of ``DiseaseObservation`` entries forming one time series."""

    # Observations in whatever order the caller supplied; this class does
    # not sort or validate them.
    observations: List[DiseaseObservation]
[docs]
@dataclass
class MultiLocationDiseaseTimeSeries:
    """Disease time series keyed by location.

    Supports ``obj[location]`` for lookup and ``obj[location] = series``
    for assignment; both delegate directly to ``timeseries_dict``.
    """

    # Maps a location key to that location's DiseaseTimeSeries. Defaults to
    # a fresh empty dict per instance.
    timeseries_dict: Dict[str, DiseaseTimeSeries] = field(default_factory=dict)

    def __setitem__(self, location, timeseries):
        """Store ``timeseries`` under ``location``, replacing any existing entry."""
        self.timeseries_dict[location] = timeseries

    def __getitem__(self, location):
        """Return the series stored for ``location``.

        Raises:
            KeyError: If ``location`` has no entry.
        """
        return self.timeseries_dict[location]

    def locations(self):
        """Return an iterator over the location keys."""
        return iter(self.timeseries_dict)

    def timeseries(self):
        """Return an iterator over the stored time series."""
        return iter(self.timeseries_dict.values())

    def filter_by_time_periods(self, time_periods: List[str]) -> "MultiLocationDiseaseTimeSeries":
        """Return a new container keeping only observations in ``time_periods``.

        Every location of this container is present in the result, even when
        none of its observations match (it then maps to an empty series).
        The original container and its series are not modified.

        Args:
            time_periods: Time-period labels to keep. Elements must be
                hashable, because they are collected into a set.

        Returns:
            A new ``MultiLocationDiseaseTimeSeries`` with filtered series.
        """
        # Build the lookup once: a set gives constant-time membership tests
        # instead of rescanning the list for every observation.
        wanted = set(time_periods)
        filtered = MultiLocationDiseaseTimeSeries()
        for location, timeseries in self.timeseries_dict.items():
            filtered[location] = DiseaseTimeSeries(
                [obs for obs in timeseries.observations if obs.time_period in wanted]
            )
        return filtered
[docs]
@dataclass
class Error:
    """One error value attached to a single time period."""

    # Label of the time period the error refers to.
    time_period: str
    # The error value itself; how it is computed is decided by the caller.
    value: float
[docs]
@dataclass
class ErrorTimeSeries:
    """A list of ``Error`` entries forming one time series."""

    # Error entries in caller-supplied order; not sorted or validated here.
    observations: List[Error]
[docs]
@dataclass
class MultiLocationErrorTimeSeries:
    """Error time series keyed by location.

    Supports ``obj[location]`` for lookup and ``obj[location] = series``
    for assignment. Several helpers require the per-location series to be
    aligned (equal length, or identical ordered time periods) and raise
    ``AssertionError`` when they are not.
    """

    # Maps a location key to that location's ErrorTimeSeries.
    timeseries_dict: Dict[str, ErrorTimeSeries]

    def __getitem__(self, location):
        """Return the series stored for ``location``.

        Raises:
            KeyError: If ``location`` has no entry.
        """
        return self.timeseries_dict[location]

    def __setitem__(self, location, timeseries):
        """Store ``timeseries`` under ``location``, replacing any existing entry."""
        self.timeseries_dict[location] = timeseries

    def locations(self):
        """Return an iterator over the location keys."""
        return iter(self.timeseries_dict)

    def timeseries(self):
        """Return an iterator over the stored time series."""
        return iter(self.timeseries_dict.values())

    def num_locations(self):
        """Return the number of locations stored."""
        return len(self.timeseries_dict)

    def num_timeperiods(self):
        """Return the number of time periods shared by all series.

        Returns 0 when the container holds no series at all.

        Raises:
            AssertionError: If the series do not share the same time periods.
        """
        periods = self.get_all_timeperiods()
        # get_all_timeperiods() yields None for an empty container; calling
        # len() on that would raise TypeError, so map it to zero instead.
        return 0 if periods is None else len(periods)

    def get_the_only_location(self):
        """Return the single location key.

        Raises:
            AssertionError: If the container does not hold exactly one location.
        """
        assert len(self.timeseries_dict) == 1, "expected exactly one location"
        return next(iter(self.timeseries_dict))

    def get_the_only_timeseries(self):
        """Return the single stored series.

        Raises:
            AssertionError: If the container does not hold exactly one location.
        """
        assert len(self.timeseries_dict) == 1, "expected exactly one location"
        return next(iter(self.timeseries_dict.values()))

    def get_all_timeperiods(self):
        """Return the ordered time periods common to every series.

        The periods of the first series are taken as the reference and each
        further series must list exactly the same periods in the same order.

        Returns:
            The list of time periods, or ``None`` when no series is stored.

        Raises:
            AssertionError: If any series has different time periods.
        """
        timeperiods = None
        for ts in self.timeseries():
            current_periods = [o.time_period for o in ts.observations]
            if timeperiods is None:
                timeperiods = current_periods
            else:
                assert timeperiods == current_periods, (
                    "time periods differ between locations"
                )
        return timeperiods

    def timeseries_length(self):
        """Return the number of observations every series has.

        Raises:
            AssertionError: If the container is empty or the series differ
                in length.
        """
        lengths = [len(ts.observations) for ts in self.timeseries()]
        assert len(set(lengths)) == 1, (
            "expected at least one series and equal lengths across locations"
        )
        return lengths[0]

    def locationvalues_per_timepoint(self) -> List[Dict[str, Error]]:
        """Regroup the data by position in the series.

        Returns:
            One dict per index ``i``, mapping each location to its ``i``-th
            entry. Entries are paired by position only; time periods are
            not compared here.

        Raises:
            AssertionError: If the container is empty or the series differ
                in length (raised by ``timeseries_length``).
        """
        length = self.timeseries_length()
        return [
            {location: ts.observations[i] for location, ts in self.timeseries_dict.items()}
            for i in range(length)
        ]
# Forecasts
[docs]
@dataclass
class Samples:
    """A collection of sampled disease-case values for one time period."""

    # Label of the time period the samples refer to.
    time_period: str
    # Sampled values for that period; presumably draws from a forecast
    # distribution -- confirm against the producer of these objects.
    disease_case_samples: List[float]
[docs]
@dataclass
class Forecast:
    """A list of ``Samples`` entries, one per predicted time period."""

    # Per-period sample sets in caller-supplied order; not validated here.
    predictions: List[Samples]
[docs]
@dataclass
class MultiLocationForecast:
    """Forecasts keyed by location."""

    # Maps a location key to that location's Forecast.
    timeseries: Dict[str, Forecast]

    def time_periods(self) -> Set[str]:
        """Return the set of time periods appearing in any location's forecast.

        The result is the union over all locations; it is empty when no
        forecast contains any prediction.
        """
        return {
            sample.time_period
            for forecast in self.timeseries.values()
            for sample in forecast.predictions
        }