"""
Note: Some of this might be outdated, but plot_forecast_from_summaries is
used in several places to create forecast plots.
"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from plotly.graph_objs import Figure
import plotly.graph_objects as go
from chap_core.datatypes import ClimateData, HealthData, SummaryStatistics
from chap_core.predictor.protocol import IsSampler
[docs]
def prediction_plot(
    true_data: HealthData,
    predicition_sampler: IsSampler,
    climate_data: ClimateData,
    n_samples,
) -> "plt.Figure":
    """Draw sampled prediction paths and the observed cases with matplotlib.

    Each of the ``n_samples`` iterations calls
    ``predicition_sampler.sample(climate_data)`` and plots the result in grey
    on the current pyplot figure; ``true_data.disease_cases`` is then plotted
    in blue on top.

    Args:
        true_data: Object whose ``disease_cases`` attribute holds the observed
            series.
        predicition_sampler: Object exposing ``sample(climate_data)``.
        climate_data: Passed unchanged to the sampler on every draw.
        n_samples: Number of paths to draw.

    Returns:
        The current matplotlib figure (``plt.gcf()``). Note that this is a
        matplotlib figure, not a plotly one.
    """
    for i in range(n_samples):
        new_observed = predicition_sampler.sample(climate_data)
        # Only the first path carries a legend label; otherwise the legend
        # would contain one identical "predicted" entry per sample.
        label = "predicted" if i == 0 else "_nolegend_"
        plt.plot(new_observed, label=label, color="grey")
    plt.plot(true_data.disease_cases, label="real", color="blue")
    plt.legend()
    plt.title("Predicted path using estimated parameters vs real path")
    return plt.gcf()
[docs]
def plot_forecast_from_summaries(
    summaries: SummaryStatistics | list[SummaryStatistics],
    true_data: HealthData,
    transform=lambda x: x,
) -> Figure:
    """Plot one or several forecast summaries against the observed cases.

    Args:
        summaries: A single summary object or a list of them. Each is turned
            into a data frame with ``topandas()`` and its ``time_period``
            column is cast to ``str``.
        true_data: Observed data; ``time_period.topandas()`` supplies the x
            labels (stringified) and ``disease_cases`` the y values.
        transform: Callable forwarded to ``plot_forecasts_from_data_frame``.

    Returns:
        Whatever ``plot_forecasts_from_data_frame`` returns for the prepared
        frames.
    """
    period_labels = [str(period) for period in true_data.time_period.topandas()]
    true_df = pd.DataFrame({"x": period_labels, "real": true_data.disease_cases})

    # Single summary: hand over one frame rather than a list of frames.
    if not isinstance(summaries, list):
        frame = summaries.topandas()
        frame.time_period = frame.time_period.astype(str)
        return plot_forecasts_from_data_frame(frame, true_df, transform)

    frames = [summary.topandas() for summary in summaries]
    for frame in frames:
        frame.time_period = frame.time_period.astype(str)
    return plot_forecasts_from_data_frame(frames, true_df, transform)
[docs]
def plot_forecast(quantiles: np.ndarray, true_data: HealthData, x_pred=None) -> Figure:
    """Plot forecast quantiles against the observed cases.

    Args:
        quantiles: Indexable with three rows; row 0 is used as the lower
            quantile, row 1 as the median and row 2 as the upper quantile
            (the original code labelled them 10th / 50th / 90th).
        true_data: Observed data; ``time_period.topandas()`` supplies the x
            labels (stringified) and ``disease_cases`` the y values.
        x_pred: Optional iterable of periods for the forecast; each element is
            converted with ``str``. Defaults to the observed periods.

    Returns:
        The figure produced by ``plot_forecasts_from_data_frame``.
    """
    x_true = [str(p) for p in true_data.time_period.topandas()]
    x_pred = x_true if x_pred is None else [str(p) for p in x_pred]
    # Column names must match what add_prediction_lines reads
    # ("time_period", "quantile_low", "median", "quantile_high"); the previous
    # "x"/"10th"/"50th"/"90th" names made every call fail with a KeyError.
    df = pd.DataFrame(
        {
            "time_period": x_pred,
            "quantile_low": quantiles[0],
            "median": quantiles[1],
            "quantile_high": quantiles[2],
        }
    )
    true_df = pd.DataFrame({"x": x_true, "real": true_data.disease_cases})
    return plot_forecasts_from_data_frame(df, true_df)
[docs]
def plot_forecasts_from_data_frame(
    prediction_df: pd.DataFrame | list[pd.DataFrame], true_df, transform=lambda x: x
) -> Figure:
    """Build a plotly figure with forecast bands and the observed series.

    Args:
        prediction_df: One forecast frame or a list of them; each is passed to
            ``add_prediction_lines``.
        true_df: Frame with an ``"x"`` column (x values) and a ``"real"``
            column (observed values).
        transform: Callable applied to the observed values before plotting and
            forwarded to ``add_prediction_lines``.

    Returns:
        The populated ``go.Figure``.
    """
    fig = go.Figure()

    # Treat a lone frame as a one-element list so there is a single code path.
    forecasts = prediction_df if isinstance(prediction_df, list) else [prediction_df]
    for forecast in forecasts:
        add_prediction_lines(fig, forecast, transform, true_df)

    # The observed series is added last, after all forecast traces.
    observed_trace = {
        "x": true_df["x"],
        "y": transform(true_df["real"]),
        "mode": "lines",
        "name": "real",
        "line": {"color": "blue"},
    }
    fig.add_scatter(**observed_trace)

    layout = {
        "title": "Predicted path using estimated parameters vs real path",
        "xaxis_title": "Time Period",
        "yaxis_title": "Disease Cases",
    }
    fig.update_layout(**layout)
    return fig
[docs]
def add_prediction_lines(fig, prediction_df, transform, true_df):
    """Add one forecast (quantile band, median, start marker) to ``fig``.

    Args:
        fig: Figure-like object exposing ``add_scatter(**kwargs)`` and
            ``add_shape(shape)``.
        prediction_df: Frame with ``"time_period"``, ``"quantile_high"``,
            ``"quantile_low"`` and ``"median"`` columns.
        transform: Callable applied to every y series before plotting.
        true_df: Frame with ``"x"`` and ``"real"`` columns; the first forecast
            period must occur in ``"x"``.

    Raises:
        IndexError: If the first forecast period is not found in
            ``true_df["x"]``.
    """
    # Positional access so a non-default index on prediction_df still works.
    first_period = prediction_df["time_period"].iloc[0]
    matches = np.flatnonzero((true_df["x"] == first_period).to_numpy())
    if matches.size == 0:
        raise IndexError(f"forecast start {first_period!r} not found in true_df['x']")
    start_idx = int(matches[0])

    # Observation just before the forecast; clamp to 0 so that a forecast
    # starting at the first observed period does not wrap around to the last
    # row through a negative index.
    anchor_idx = max(start_idx - 1, 0)

    if start_idx != 0:
        # Prepend the last observed value so the forecast traces connect to
        # the observed line instead of starting in mid-air.
        anchor = true_df.iloc[anchor_idx]
        bridge = pd.DataFrame(
            {
                "time_period": [anchor["x"]],
                "quantile_high": anchor["real"],
                "quantile_low": anchor["real"],
                "median": anchor["real"],
            }
        )
        prediction_df = pd.concat([bridge, prediction_df], ignore_index=True)

    x_values = prediction_df["time_period"]
    upper = transform(prediction_df["quantile_high"])
    fig.add_scatter(
        x=x_values,
        y=upper,
        mode="lines",
        line=dict(color="lightgrey"),
        name="quantile_high",
    )
    # "tonexty" fills towards the previously added trace (the upper quantile),
    # so this trace must directly follow it.
    fig.add_scatter(
        x=x_values,
        y=transform(prediction_df["quantile_low"]),
        mode="lines",
        line=dict(color="lightgrey"),
        fill="tonexty",
        fillcolor="rgba(68, 68, 68, 0.3)",
        name="quantile_low",
    )
    fig.add_scatter(
        x=x_values,
        y=transform(prediction_df["median"]),
        mode="lines",
        line=dict(color="grey"),
        name="Median",
    )

    # Vertical marker at the last observation before the forecast. Its height
    # is computed on the transformed scale (the same scale as the traces) and
    # ignores NaN values.
    marker_x = true_df["x"].iloc[anchor_idx]
    marker_top = float(max(np.nanmax(upper), np.nanmax(transform(true_df["real"]))))
    fig.add_shape(
        dict(
            type="line",
            x0=marker_x,
            x1=marker_x,
            y0=0,
            y1=marker_top,
            line=dict(color="red", width=2),
        )
    )