Source code for chap_core.assessment.dataset_splitting

from typing import Iterable, Protocol, Optional, Type

from chap_core.climate_predictor import FutureWeatherFetcher
from chap_core.datatypes import ClimateData
from chap_core.spatio_temporal_data.temporal_dataclass import DataSet
from chap_core.time_period import TimePeriod
from chap_core.time_period.relationships import previous


class IsTimeDelta(Protocol):
    """Structural placeholder type for time-delta-like arguments.

    Used in this module to annotate ``future_length`` / ``extension`` parameters.
    It declares no members, so it imposes no structural requirements; it exists
    purely for annotation purposes.
    """
def split_test_train_on_period(
    data_set: DataSet,
    split_points: Iterable[TimePeriod],
    future_length: Optional[IsTimeDelta] = None,
    include_future_weather: bool = False,
    future_weather_class: Type[ClimateData] = ClimateData,
):
    """Lazily split ``data_set`` into train/test pieces at every split point.

    Parameters
    ----------
    data_set
        The full dataset to split.
    split_points
        Periods at which to split; each one is used as the first period of the
        test data. Consumed lazily as the returned generator is iterated.
    future_length
        Forwarded as ``extension``. When ``None`` the test data is open-ended
        (runs from the split point to the end of the data).
    include_future_weather
        If True, use ``train_test_split_with_weather`` and also yield the
        future weather data; otherwise use ``train_test_split``.
    future_weather_class
        Forwarded to ``train_test_split_with_weather``; ignored otherwise.

    Returns
    -------
    Generator
        Yields ``(train, test)`` per split point, or ``(train, test, future_weather)``
        when ``include_future_weather`` is True.
    """
    if include_future_weather:
        return (
            train_test_split_with_weather(data_set, period, future_length, future_weather_class)
            for period in split_points
        )
    return (train_test_split(data_set, period, future_length) for period in split_points)
def train_test_split(
    data_set: DataSet,
    prediction_start_period: TimePeriod,
    extension: Optional[IsTimeDelta] = None,
    restrict_test=True,
):
    """Split ``data_set`` at ``prediction_start_period``.

    Parameters
    ----------
    data_set
        The full dataset.
    prediction_start_period
        First period of the test data; training data ends at the period
        immediately before it.
    extension
        If given, the test data ends at ``prediction_start_period.extend_to(extension)``;
        if ``None`` the test data is open-ended.
    restrict_test
        If False, the test data is the unrestricted ``data_set``.

    Returns
    -------
    tuple
        ``(train_data, test_data)``.
    """
    train_end = previous(prediction_start_period)
    train_data = data_set.restrict_time_period(slice(None, train_end))

    # Computed even when restrict_test is False, matching the original call order.
    test_end = None if extension is None else prediction_start_period.extend_to(extension)

    test_data = (
        data_set.restrict_time_period(slice(prediction_start_period, test_end))
        if restrict_test
        else data_set
    )
    return train_data, test_data
def train_test_generator(
    dataset: DataSet,
    prediction_length: int,
    n_test_sets: int = 1,
    stride: int = 1,
    future_weather_provider: Optional[FutureWeatherFetcher] = None,
) -> tuple[DataSet, Iterable[tuple[DataSet, DataSet, DataSet]]]:
    """
    Generate a train set along with an iterator of test data that contains tuples of full data
    up until a split point and data without target variables for the remaining steps.

    The test windows are placed at the end of ``dataset.period_range``: the last window ends on
    the final period, and each earlier window starts ``stride`` periods before the next one.
    The train set ends on the same period as the first test set's historic data.

    Parameters
    ----------
    dataset
        The full dataset
    prediction_length
        How many periods to predict
    n_test_sets
        How many test sets to generate
    stride
        How many periods to stride between test sets
    future_weather_provider
        Optional factory for future weather. It is called with each historic dataset, and the
        returned object's ``get_future_weather`` is called with the future window's period range.
        When omitted, the future data with the ``disease_cases`` field removed is used instead.

    Returns
    -------
    tuple[DataSet, Iterable[tuple[DataSet, DataSet, DataSet]]]
        The train set, and an iterator of ``(historic_data, masked_future_data, future_data)``
        triples, one per test set.

    Notes
    -----
    Indexing into ``dataset.period_range`` presumably fails (IndexError) if the dataset has fewer
    than ``prediction_length + (n_test_sets - 1) * stride + 1`` periods.
    """
    periods = dataset.period_range
    # Negative index of the last training period: leaves room for the final test window
    # (prediction_length periods) plus (n_test_sets - 1) strides before it.
    split_idx = -(prediction_length + (n_test_sets - 1) * stride + 1)
    # Index of the last historic (fully observed) period for each test set.
    last_historic_idxs = [split_idx + i * stride for i in range(n_test_sets)]

    train_set = dataset.restrict_time_period(slice(None, periods[split_idx]))
    historic_data = [
        dataset.restrict_time_period(slice(None, periods[idx])) for idx in last_historic_idxs
    ]
    # Each future window covers the prediction_length periods after its historic cut-off.
    future_data = [
        dataset.restrict_time_period(slice(periods[idx + 1], periods[idx + prediction_length]))
        for idx in last_historic_idxs
    ]

    if future_weather_provider is not None:
        masked_future_data = [
            future_weather_provider(hd).get_future_weather(fd.period_range)
            for hd, fd in zip(historic_data, future_data)
        ]
    else:
        # Lazily drop the target column from each future window.
        masked_future_data = (fd.remove_field("disease_cases") for fd in future_data)

    # Give the train set its own metadata copy so renaming it leaves ``dataset`` untouched.
    train_set.metadata = dataset.metadata.model_copy()
    train_set.metadata.name += "_train_set"
    return train_set, zip(historic_data, masked_future_data, future_data)
def train_test_split_with_weather(
    data_set: DataSet,
    prediction_start_period: TimePeriod,
    extension: Optional[IsTimeDelta] = None,
    future_weather_class: Type[ClimateData] = ClimateData,
):
    """Split like ``train_test_split`` and additionally derive future weather data.

    Parameters
    ----------
    data_set
        The full dataset.
    prediction_start_period
        First period of the test data.
    extension
        Forwarded to ``train_test_split`` to bound the end of the test data.
    future_weather_class
        Currently unused; kept for interface compatibility with callers.

    Returns
    -------
    tuple
        ``(train_set, test_set, future_weather)`` where ``future_weather`` is
        ``test_set`` with the ``disease_cases`` field removed.

    Raises
    ------
    AssertionError
        If any period (compared by its string form) appears in both the train set
        and the future weather data.
    """
    train_set, test_set = train_test_split(data_set, prediction_start_period, extension)
    future_weather = test_set.remove_field("disease_cases")

    train_periods = {str(period) for data in train_set.data() for period in data.data().time_period}
    future_periods = {
        str(period) for data in future_weather.data() for period in data.data().time_period
    }
    # Explicit raise (not ``assert``) so the leakage check still runs under ``python -O``.
    overlap = train_periods & future_periods
    if overlap:
        raise AssertionError(f"Train and future weather data overlap: {overlap}")
    return train_set, test_set, future_weather
def get_split_points_for_data_set(data_set: DataSet, max_splits: int, start_offset=1) -> list[TimePeriod]:
    """Return up to ``max_splits`` evenly spaced split periods for ``data_set``.

    Parameters
    ----------
    data_set
        Dataset whose time axis is used. Only the first location's periods are
        inspected; all locations are assumed to share the same periods.
    max_splits
        Maximum number of split points to return.
    start_offset
        Number of leading periods to skip before placing split points.

    Raises
    ------
    ValueError
        If ``data_set`` contains no locations (previously a bare ``StopIteration``
        leaked out of ``next``).
    """
    first_location = next(iter(data_set.data()), None)
    if first_location is None:
        raise ValueError("Cannot compute split points for an empty data set")
    periods = first_location.data().time_period
    return get_split_points_for_period_range(max_splits, periods, start_offset)
def get_split_points_for_period_range(max_splits, periods, start_offset):
    """Pick up to ``max_splits`` evenly spaced split points from ``periods``.

    The usable span (after ``start_offset`` and excluding the last period) is
    divided into ``max_splits + 1`` equal steps of ``delta`` periods. Split
    points are taken every ``delta`` periods, starting at index
    ``start_offset + delta``.

    Parameters
    ----------
    max_splits
        Maximum number of split points to return; ``<= 0`` yields ``[]``.
    periods
        Any iterable of periods (materialized to a list).
    start_offset
        Number of leading periods to skip.

    Returns
    -------
    list
        At most ``max_splits`` periods.

    Raises
    ------
    ValueError
        If there are too few periods for the computed step to be at least 1.
        Previously a zero step raised a cryptic slice error, and a negative step
        silently returned meaningless split points.
    """
    if max_splits <= 0:
        return []
    period_list = list(periods)
    delta = (len(period_list) - 1 - start_offset) // (max_splits + 1)
    if delta < 1:
        raise ValueError(
            f"Too few periods ({len(period_list)}) to place {max_splits} split point(s) "
            f"after an offset of {start_offset}"
        )
    return period_list[start_offset + delta :: delta][:max_splits]