Source code for chap_core.plotting.backtest_plot

from abc import ABC, abstractmethod

import pandas as pd
from altair import FacetChart

from chap_core.assessment.flat_representations import convert_backtest_observations_to_flat_observations
from chap_core.database.tables import BackTest
import altair as alt
import textwrap

# Import-time side effect: switch Altair's global data transformer to "vegafusion".
# The setting is process-wide, so it also affects charts built outside this module
# once this module has been imported.
# NOTE(review): presumably chosen so large backtest frames are not rejected by
# Altair's default row limit - confirm against the Altair/VegaFusion docs.
alt.data_transformers.enable("vegafusion")


def title_chart(text: str, width: int = 600, font_size: int = 24, pad: int = 10):
    """Build an Altair chart whose only content is a bold heading.

    Parameters
    ----------
    text : str
        The heading to display.
    width : int
        Width given to the chart's properties.
    font_size : int
        Font size of the heading; also drives the chart height.
    pad : int
        Extra height added on top of ``font_size``.

    Returns
    -------
    Chart
        A text-mark chart of height ``font_size + pad``.
    """
    # A text mark is drawn once per data row, so a single dummy row is supplied.
    anchor_row = pd.DataFrame({"x": [0], "y": [0]})
    heading = alt.Chart(anchor_row).mark_text(
        text=text,
        fontSize=font_size,
        fontWeight="bold",
        align="center",
        baseline="top",
    )
    chart_height = font_size + pad
    return heading.properties(width=width, height=chart_height)
def text_chart(
    text: str,
    line_length: int = 80,
    font_size: int = 12,
    align: str = "left",
    pad_bottom: int = 50,
):
    """Return an Altair chart that renders ``text`` as wrapped lines.

    The text is wrapped with ``textwrap.wrap`` and each resulting line becomes
    one row of the chart's data, drawn as a text mark on an ordinal y position
    with the axis hidden.

    Parameters
    ----------
    text : str
        The text to display.
    line_length : int
        Maximum number of characters per wrapped line.
    font_size : int
        Font size of the text; also used to derive the chart height.
    align : str
        Horizontal alignment passed to the text mark.
    pad_bottom : int
        Extra height added below the last line.

    Returns
    -------
    Chart
        A text-mark chart of height ``len(lines) * (font_size + 2) + pad_bottom``.
    """
    # `alt` and `pd` come from the module-level imports; the previous
    # function-local re-imports were redundant.
    lines = textwrap.wrap(text, width=line_length)
    df = pd.DataFrame({"line": lines, "y": range(len(lines))})

    # Reserve two extra pixels per line so consecutive lines do not touch.
    line_spacing = font_size + 2
    total_height = len(lines) * line_spacing + pad_bottom

    return (
        alt.Chart(df)
        .mark_text(align=align, baseline="top", fontSize=font_size)
        .encode(text="line", y=alt.Y("y:O", axis=None))
        .properties(height=total_height)
    )
def clean_time(period: str) -> str:
    """Convert a monthly period string to an ISO date for Altair/vegafusion.

    Two monthly spellings are recognised and expanded to the first day of the
    month:

    - ``YYYYMM``  -> ``YYYY-MM-01``
    - ``YYYY-MM`` -> ``YYYY-MM-01``

    Any other string is returned unchanged. The year and month positions must
    be digits, so strings that merely have the same length (for example
    ``"2020W1"`` or ``"2020-W1"``) are passed through instead of being turned
    into an invalid date.

    Parameters
    ----------
    period : str
        The period identifier to normalise.

    Returns
    -------
    str
        The ISO date string, or ``period`` itself when it is not monthly.
    """
    # YYYYMM: six digits, no separator.
    if len(period) == 6 and period.isdigit():
        return f"{period[:4]}-{period[4:]}-01"

    # YYYY-MM: digits on both sides of a dash in position 4.
    if (
        len(period) == 7
        and period[4] == "-"
        and period[:4].isdigit()
        and period[5:].isdigit()
    ):
        return f"{period}-01"

    return period
class BackTestPlotBase(ABC):
    """
    Common interface for plots built from a backtest.

    A concrete subclass provides:
    - ``from_backtest``: class method that builds the plot from a BackTest
    - ``plot``: method that produces and returns the visualization
    - ``name``: class variable holding the name of the plot type
    """

    # Name of the plot type; empty on the base class.
    name: str = ""

    @classmethod
    @abstractmethod
    def from_backtest(cls, backtest: BackTest):
        """
        Build a plot instance from a BackTest object.

        Parameters
        ----------
        backtest : BackTest
            The backtest holding the forecast and observation data

        Returns
        -------
        BackTestPlotBase
            An instance of the concrete plot class
        """

    @abstractmethod
    def plot(self):
        """
        Produce and return the visualization.

        Returns
        -------
        Chart object (implementation-specific)
            The visualization object (e.g., FacetChart for Altair-based plots)
        """
class EvaluationBackTestPlot(BackTestPlotBase):
    """
    Backtest-plot that shows truth vs predictions over time.

    The chart is faceted with one column per split period and one row per
    location. Each facet shows the forecast median, two forecast bands
    (q_10-q_90 and q_25-q_75) and the observed series.
    """

    name: str = "Evaluation Plot"

    def __init__(self, forecast_df: pd.DataFrame, observed_df: pd.DataFrame):
        """
        Store the frames that ``plot`` draws from.

        Parameters
        ----------
        forecast_df : pd.DataFrame
            Forecast rows; ``plot`` reads the columns ``time_period``,
            ``location``, ``split_period`` and ``q_10`` ... ``q_90``.
        observed_df : pd.DataFrame
            Observation rows; ``plot`` reads ``time_period``, ``location``
            and ``disease_cases``.
        """
        self._forecast = forecast_df
        self._observed = observed_df

    @classmethod
    def from_backtest(cls, backtest: BackTest) -> "EvaluationBackTestPlot":
        """
        Create the plot from a BackTest object.

        Parameters
        ----------
        backtest : BackTest
            The backtest whose forecasts and dataset observations are plotted

        Returns
        -------
        EvaluationBackTestPlot
            Plot instance holding one forecast row per backtest forecast and
            the flattened observations
        """
        quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
        # Column names q_10, q_25, q_50, q_75, q_90 - the names `plot` encodes.
        quantile_columns = [f"q_{int(q * 100)}" for q in quantiles]

        rows = [
            {
                "time_period": clean_time(bt_forecast.period),
                "location": bt_forecast.org_unit,
                "split_period": clean_time(bt_forecast.last_seen_period),
            }
            | dict(zip(quantile_columns, bt_forecast.get_quantiles(quantiles)))
            for bt_forecast in backtest.forecasts
        ]
        forecast_df = pd.DataFrame(rows)

        flat_observations = convert_backtest_observations_to_flat_observations(backtest.dataset.observations)
        flat_observations["time_period"] = flat_observations["time_period"].apply(clean_time)
        return cls(forecast_df, flat_observations)

    def plot(self) -> FacetChart:
        """
        Build the faceted forecast-vs-observation chart.

        Returns
        -------
        FacetChart
            Layered chart faceted by ``split_period`` (columns) and
            ``location`` (rows), with an independent y scale per facet
        """
        # The chart is faceted by split_period, so the observations are
        # repeated once per split period to appear in every facet column.
        split_periods = self._forecast["split_period"].unique()
        observed_data = pd.concat(
            [self._observed.assign(split_period=split_period) for split_period in split_periods],
            ignore_index=True,
        ).assign(data_type="observed")
        forecast_data = self._forecast.assign(data_type="forecast")

        # Drop all-NA columns before concatenation to avoid a FutureWarning.
        # Columns present in only one frame are filled with NaN by the concat,
        # so no explicit placeholder columns are needed beforehand.
        forecast_data = forecast_data.dropna(axis=1, how="all")
        observed_data = observed_data.dropna(axis=1, how="all")
        combined_data = pd.concat([forecast_data, observed_data], ignore_index=True)

        base = alt.Chart(combined_data)
        forecast_layer = base.transform_filter(alt.datum.data_type == "forecast")
        observed_layer = base.transform_filter(alt.datum.data_type == "observed")

        def quantile_band(lower: str, upper: str, opacity: float):
            """Return a blue error band between two quantile columns."""
            return forecast_layer.mark_errorband(color="blue", opacity=opacity).encode(
                x="time_period:T",
                y=alt.Y(f"{lower}:Q", scale=alt.Scale(zero=False)),
                y2=f"{upper}:Q",
            )

        # Wide, faint band underneath; narrow, stronger band on top of it.
        outer_band = quantile_band("q_10", "q_90", 0.3)
        inner_band = quantile_band("q_25", "q_75", 0.5)

        # Forecast line (median)
        median_line = forecast_layer.mark_line().encode(
            x="time_period:T",
            y=alt.Y("q_50:Q", scale=alt.Scale(zero=False)),
        )

        # Observations
        observations = observed_layer.mark_line(color="orange").encode(
            x="time_period:T",
            y=alt.Y("disease_cases:Q", scale=alt.Scale(zero=False)),
        )

        # Layer order decides draw order: bands first, lines on top.
        full_layer = outer_band + inner_band + median_line + observations

        return (
            full_layer.facet(column="split_period:O", row="location:N")
            .resolve_scale(y="independent")
            .properties(title="BackTest Forecasts with Observations")
        )