Validates that a model's training and prediction functions work correctly with a specific example dataset and produce output with the expected structure (matching time periods and locations).
Usage
validate_model_io(
train_fn,
predict_fn,
example_data,
model_configuration = list()
)Arguments
- train_fn
Training function that takes (training_data, model_configuration = list()) and returns a trained model object
- predict_fn
Prediction function that takes (historic_data, future_data, saved_model, model_configuration = list()) and returns predictions with a
sampleslist-column.- example_data
Named list containing training_data, historic_data, future_data, and predictions (as returned by get_example_data())
- model_configuration
Optional list of model configuration parameters to pass to train_fn and predict_fn
Value
A list with validation results:
success: Logical indicating if validation passederrors: Character vector of error messages (if any)n_predictions: Number of prediction rows returnedn_samples: Number of samples per forecast unit
Details
All models must return predictions with a samples list-column containing
numeric vectors. For deterministic models, use a single sample per forecast unit
(e.g., samples = list(c(42))). For probabilistic models, include multiple
Monte Carlo samples (e.g., 1000 samples per forecast unit).
Validation Checks
The following checks are performed on predictions:
Predictions must be a data frame or tibble
Must have a
sampleslist-column containing numeric vectorsAll rows must have the same number of samples
Row count must match the number of rows in
future_dataPredictions must not contain NaN values (Chap contract requirement)
Predictions must not contain NA values
Predictions must be non-negative (Chap contract requirement)
Examples
if (FALSE) { # \dontrun{
# Define simple model functions
my_train <- function(training_data, model_configuration = list()) {
means <- training_data |>
dplyr::group_by(location) |>
dplyr::summarise(mean_cases = mean(disease_cases, na.rm = TRUE))
list(means = means)
}
my_predict <- function(historic_data, future_data, saved_model,
model_configuration = list()) {
future_data |>
tibble::as_tibble() |>
dplyr::left_join(saved_model$means, by = "location") |>
dplyr::mutate(samples = purrr::map(mean_cases, ~c(.x))) |>
dplyr::select(-mean_cases)
}
# Get example data and validate
example_data <- get_example_data('laos', 'M')
result <- validate_model_io(my_train, my_predict, example_data)
if (result$success) {
cat("Model validation passed!\n")
} else {
cat("Validation failed:\n")
print(result$errors)
}
} # }