Source code for chap_core.api

import logging

from .assessment.forecast import forecast as do_forecast
from typing import Optional, List
from .datatypes import (
    HealthData,
    ClimateData,
    HealthPopulationData,
)
from .models.utils import get_model_from_directory_or_github_url
from .file_io.example_data_set import DataSetType, datasets
from .plotting.prediction_plot import plot_forecast_from_summaries
from .predictor import get_model
from .spatio_temporal_data.temporal_dataclass import DataSet
import dataclasses

from .time_period.date_util_wrapper import delta_month


logger = logging.getLogger(__name__)


class DummyControl:
    """Stand-in control object whose operations do nothing.

    ``set_status`` discards whatever it is given, and ``current_control``
    always reports that there is no underlying control.
    """

    def set_status(self, status):
        """Discard *status*; nothing is stored or forwarded."""
        return None

    current_control = property(
        lambda self: None,
        doc="Always ``None``: this dummy wraps no real control object.",
    )
@dataclasses.dataclass
class AreaPolygons:
    """Reference to a shape file describing area polygons.

    Attributes:
        shape_file: Location of the shape file (presumably a filesystem
            path -- TODO confirm against callers). Required; no default.
    """

    shape_file: str = dataclasses.field()
@dataclasses.dataclass
class PredictionData:
    """Container for the optional pieces of data passed to a prediction.

    Every field defaults to ``None`` so callers can populate only the parts
    they have; the annotations are ``Optional[...]`` to reflect that.
    """

    # Polygon geometry for the areas involved; None when not supplied.
    area_polygons: Optional[AreaPolygons] = None
    # Health observations as a DataSet of HealthData records.
    health_data: Optional[DataSet[HealthData]] = None
    # Climate observations as a DataSet of ClimateData records.
    climate_data: Optional[DataSet[ClimateData]] = None
    # Population data as a DataSet of HealthPopulationData records.
    population_data: Optional[DataSet[HealthPopulationData]] = None
    # Identifier of the disease; None when not supplied.
    disease_id: Optional[str] = None
    # Extra feature objects; kept as None (not []) by default so existing
    # "is None" checks in callers keep working.
    features: Optional[List[object]] = None
def extract_disease_name(health_data: dict) -> str:
    """Return the first cell of the first row in ``health_data["rows"]``.

    That cell is presumably the disease name (per the function name --
    confirm against the payload producer). Lookup errors propagate unchanged,
    e.g. ``KeyError`` when ``"rows"`` is missing, or ``IndexError`` when the
    rows list is empty.
    """
    rows = health_data["rows"]
    first_row = rows[0]
    return first_row[0]
def forecast(
    model_name: str,
    dataset_name: DataSetType,
    n_months: int,
    model_path: Optional[str] = None,
):
    """Forecast ``n_months`` ahead on an example dataset and plot the result.

    Parameters
    ----------
    model_name:
        Name passed to ``get_model``. The special value ``"external"`` loads
        the model from ``model_path`` instead.
    dataset_name:
        Key into ``datasets`` selecting which example dataset to load.
    n_months:
        Forecast horizon in months (multiplied by ``delta_month``).
    model_path:
        Local directory or GitHub URL of the model, passed to
        ``get_model_from_directory_or_github_url``. Required when
        ``model_name == "external"``; ignored otherwise.

    Returns
    -------
    list
        One figure per location in the predictions, each produced by
        ``plot_forecast_from_summaries``.

    Raises
    ------
    ValueError
        If ``model_name`` is ``"external"`` but ``model_path`` is ``None``.
    """
    # Fail fast, before the dataset load, instead of handing None to the loader.
    if model_name == "external" and model_path is None:
        raise ValueError("model_path must be provided when model_name is 'external'")

    # NOTE(review): configuring the root logger from a library function is
    # usually the application's job; kept for backward compatibility.
    logging.basicConfig(level=logging.INFO)
    dataset = datasets[dataset_name].load()

    if model_name == "external":
        # The external loader's result is used directly.
        model = get_model_from_directory_or_github_url(model_path)
    else:
        # get_model returns something that must still be instantiated
        # (presumably a model class).
        model = get_model(model_name)()

    predictions = do_forecast(model, dataset, n_months * delta_month)
    return [
        plot_forecast_from_summaries(
            prediction.data(), dataset.get_location(location).data()
        )
        for location, prediction in predictions.items()
    ]