Source code for chap_core.api
import logging
from .assessment.forecast import forecast as do_forecast
from typing import Optional, List
from .datatypes import (
HealthData,
ClimateData,
HealthPopulationData,
)
from .models.utils import get_model_from_directory_or_github_url
from .file_io.example_data_set import DataSetType, datasets
from .plotting.prediction_plot import plot_forecast_from_summaries
from .predictor import get_model
from .spatio_temporal_data.temporal_dataclass import DataSet
import dataclasses
from .time_period.date_util_wrapper import delta_month
logger = logging.getLogger(__name__)
[docs]
class DummyControl:
    """Do-nothing object exposing ``set_status`` and ``current_control``.

    Every operation is a no-op, so it can be passed wherever such an
    object is expected but no real status handling is wanted.
    """

    def set_status(self, status):
        """Accept *status* and discard it; nothing is stored or returned."""
        del status  # intentionally ignored
        return None

    @property
    def current_control(self):
        """Always ``None``: this dummy wraps no underlying control."""
        return None
[docs]
@dataclasses.dataclass()
class AreaPolygons:
    """Dataclass holding a single ``shape_file`` string.

    Attributes:
        shape_file: Reference to the area shape file (presumably a file
            path or name -- confirm against callers).
    """

    shape_file: str
[docs]
@dataclasses.dataclass
class PredictionData:
    """Container for the optional inputs of a prediction request.

    Every field defaults to ``None``, so callers supply only the parts
    they have. The annotations are ``Optional[...]`` to match those
    defaults: implicit Optional (a non-Optional annotation with a ``None``
    default) is rejected by modern type checkers.
    """

    # Area geometry, if provided.
    area_polygons: Optional[AreaPolygons] = None
    # Health observations as a DataSet of HealthData records.
    health_data: Optional[DataSet[HealthData]] = None
    # Climate covariates as a DataSet of ClimateData records.
    climate_data: Optional[DataSet[ClimateData]] = None
    # Combined health/population records as a DataSet of HealthPopulationData.
    population_data: Optional[DataSet[HealthPopulationData]] = None
    # Identifier of the disease being predicted, if any.
    disease_id: Optional[str] = None
    # Free-form feature objects; left as None (not []) so callers relying
    # on an ``is None`` check keep working.
    features: Optional[List[object]] = None
[docs]
def forecast(
    model_name: str,
    dataset_name: DataSetType,
    n_months: int,
    model_path: Optional[str] = None,
):
    """Forecast a bundled example dataset and plot one figure per location.

    Args:
        model_name: Name passed to ``get_model``. The special value
            ``"external"`` instead loads the model from ``model_path``.
        dataset_name: Key into ``datasets`` selecting the example data set.
        n_months: Forecast horizon in months (multiplied by ``delta_month``).
        model_path: Directory or GitHub URL of the model. Required when
            ``model_name == "external"``; ignored otherwise.

    Returns:
        A list with one figure per forecast location, each produced by
        ``plot_forecast_from_summaries`` from the location's prediction
        summary and its observed data.

    Raises:
        ValueError: If ``model_name`` is ``"external"`` but no
            ``model_path`` was given.
    """
    # NOTE(review): this configures the root logger as a side effect of a
    # library call; kept for backward compatibility with existing callers.
    logging.basicConfig(level=logging.INFO)
    is_external = model_name == "external"
    if is_external and not model_path:
        # Fail fast, before the dataset load, instead of handing None to the
        # loader and getting an obscure downstream error.
        raise ValueError('model_path is required when model_name is "external"')

    dataset = datasets[dataset_name].load()
    if is_external:
        model = get_model_from_directory_or_github_url(model_path)
    else:
        # get_model's result is called to obtain the model instance.
        model = get_model(model_name)()

    predictions = do_forecast(model, dataset, n_months * delta_month)
    return [
        plot_forecast_from_summaries(
            prediction.data(), dataset.get_location(location).data()
        )
        for location, prediction in predictions.items()
    ]