Source code for chap_core.climate_predictor

import dataclasses
from collections import defaultdict

import numpy as np
from sklearn import linear_model

from .datatypes import ClimateData, SimpleClimateData
from chap_core.spatio_temporal_data.temporal_dataclass import DataSet
from chap_core.time_period import PeriodRange, Month, Week


[docs] def get_climate_predictor(train_data: DataSet[ClimateData]): if isinstance(train_data.period_range[0], Month): estimator = MonthlyClimatePredictor() else: assert isinstance(train_data.period_range[0], Week) estimator = WeeklyClimatePredictor() estimator.train(train_data) return estimator
[docs] class MonthlyClimatePredictor: def __init__(self): self._models = defaultdict(dict) self._cls = None def _feature_matrix(self, time_period: PeriodRange): return time_period.month[:, None] == np.arange(1, 13)
[docs] def train(self, train_data: DataSet[ClimateData]): train_data = train_data.remove_field("disease_cases") for location, data in train_data.items(): self._cls = data.__class__ x = self._feature_matrix(data.time_period) for field in dataclasses.fields(data): if field.name in ("time_period"): continue y = getattr(data, field.name) # assert float type assert y.dtype.kind in ("f", "i"), (field.name, y.dtype) assert not np.isnan(y).any(), (field.name, y) model = linear_model.LinearRegression() model.fit(x, y[:, None]) self._models[location][field.name] = model
[docs] def predict(self, time_period: PeriodRange): x = self._feature_matrix(time_period) prediction_dict = {} for location, models in self._models.items(): prediction_dict[location] = self._cls( time_period, **{field: model.predict(x).ravel() for field, model in models.items()}, ) return DataSet(prediction_dict)
[docs] class WeeklyClimatePredictor(MonthlyClimatePredictor): def _feature_matrix(self, time_period: PeriodRange): t = time_period.week[:, None] == np.arange(1, 53) t[..., -1] |= time_period.week == 53 return t
[docs] class FutureWeatherFetcher:
[docs] def get_future_weather(self, period_range: PeriodRange) -> DataSet[SimpleClimateData]: ...
[docs] class SeasonalForecastFetcher: def __init__(self, folder_path): self.folder_path = folder_path
[docs] def get_future_weather(self, period_range: PeriodRange) -> DataSet[SimpleClimateData]: ...
[docs] class QuickForecastFetcher: def __init__(self, historical_data: DataSet[SimpleClimateData]): self._climate_predictor = get_climate_predictor(historical_data)
[docs] def get_future_weather(self, period_range: PeriodRange) -> DataSet[SimpleClimateData]: return self._climate_predictor.predict(period_range)
[docs] class FetcherNd: def __init__(self, historical_data: DataSet[SimpleClimateData]): self.historical_data = historical_data self._cls = list(historical_data.values())[0].__class__
[docs] def get_future_weather(self, period_range: PeriodRange) -> DataSet[SimpleClimateData]: prediction_dict = {} for location, data in self.historical_data.items(): prediction_dict[location] = self._cls( period_range, **{ field.name: getattr(data, field.name)[-len(period_range) :] for field in dataclasses.fields(data) if field.name not in ("time_period",) }, ) return DataSet(prediction_dict)