Source code for chap_core.adaptors.rest_api
from typing import List
import pydantic
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from chap_core.datatypes import remove_field
from chap_core.model_spec import get_dataclass
from chap_core.spatio_temporal_data.temporal_dataclass import DataSet
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
def generate_app(estimator, working_dir: str):
    """
    Build a FastAPI application exposing ``train`` and ``predict`` endpoints.

    Parameters
    ----------
    estimator
        Object providing ``train(dataset)`` and ``load_predictor(filename)``.
        The dataclass returned by ``get_dataclass(estimator)`` defines the
        schema of the training observations accepted by the ``train`` endpoint.
    working_dir: str
        Directory under which the trained model is saved
        (``<working_dir>/model``).

    Returns
    -------
    FastAPI
        The configured application with CORS middleware and both endpoints
        registered.
    """
    app = FastAPI()

    # NOTE(review): "*" already matches every origin, so the explicit localhost
    # entries are redundant. Combining a wildcard origin with
    # allow_credentials=True is discouraged by the FastAPI CORS documentation;
    # kept as-is to preserve the existing behaviour -- confirm the intent.
    origins = [
        "*",  # Allow all origins
        "http://localhost:3000",
        "localhost:3000",
    ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    dc = get_dataclass(estimator)
    # Use the explicit ``(type, ...)`` form so every annotation becomes a
    # *required* field of that type; a bare type is read as a default value by
    # some pydantic versions.
    field_definitions = {name: (annotation, ...) for name, annotation in dc.__annotations__.items()}
    model = pydantic.create_model("TrainingData", **field_definitions)
    model_path = f"{working_dir}/model"

    # FastAPI registers handlers through HTTP-method decorators; it has no
    # ``command`` decorator (that is a CLI-framework API).
    @app.post("/train")
    def train(training_data: List[model]):
        """
        Train a model using historic data and save it in the working directory

        Parameters
        ----------
        training_data: List[TrainingData]
            The historic observations, sent as the JSON request body
        """
        logger.info("Training on %d observations as %s", len(training_data), dc)
        # NOTE(review): assumes ``df_from_pydantic_observations`` takes the list
        # of observations and returns something ``estimator.train`` accepts --
        # confirm against the DataSet API.
        dataset = DataSet.df_from_pydantic_observations(training_data)
        predictor = estimator.train(dataset)
        predictor.save(model_path)

    @app.post("/predict")
    def predict(model_filename: str, historic_data_filename: str, future_data_filename: str, output_filename: str):
        """
        Predict using a trained model

        Parameters
        ----------
        model_filename: str
            The path to the model file trained with the train command
        historic_data_filename: str
            The path to the historic data file, i.e. real data up to the present/prediction start
        future_data_filename: str
            The path to the future data file, i.e. forecasted predictors for the future
        output_filename: str
            The path the forecasts are written to as CSV
        """
        dataset = DataSet.from_csv(historic_data_filename, dc)
        # Future data carries only the predictors, so the target column is
        # dropped from the schema before reading it.
        future_dc = remove_field(dc, "disease_cases")
        future_data = DataSet.from_csv(future_data_filename, future_dc)
        predictor = estimator.load_predictor(model_filename)
        forecasts = predictor.predict(dataset, future_data)
        forecasts.to_csv(output_filename)

    return app