import abc
from typing import Optional
import altair as alt
from chap_core.assessment.flat_representations import (
FlatMetric,
convert_backtest_observations_to_flat_observations,
convert_backtest_to_flat_forecasts,
)
from chap_core.assessment.metrics.base import MetricBase
from chap_core.database.base_tables import DBModel
from chap_core.database.tables import BackTest
# Switch Altair to its "browser" renderer as soon as this module is imported.
# NOTE(review): this is a module-level side effect on Altair's renderer registry;
# confirm it is wanted when the module is imported from a server/headless process.
alt.renderers.enable("browser")
[docs]
class MetricPlotV2(abc.ABC):
    """
    Represents types of metrics plots, that always start from raw FlatMetric data.
    Different plots can process this data in the way they want to produce a plot.

    Subclasses implement ``plot_from_df``; callers normally use ``plot`` or ``plot_spec``.
    """

    def __init__(self, metric_data: FlatMetric, geojson: Optional[dict] = None):
        """
        Keep a reference to the metric data that the plot is built from.

        Args:
            metric_data: Flat metric data read by ``plot_from_df`` implementations.
            geojson: Accepted so every plot class can be constructed the same way;
                this base class does not store it.
        """
        self._metric_data = metric_data

    def plot(self, title: Optional[str] = None) -> alt.Chart:
        """
        Build the chart for this plot type.

        Args:
            title: Chart title. When omitted, the subclass's own default title is used
                instead of forcing one generic title onto every plot type.

        Returns:
            Whatever ``plot_from_df`` returns (an Altair chart).
        """
        if title is None:
            # Do not pass ``title`` at all: this keeps each subclass's default and also
            # works for implementations whose ``plot_from_df`` takes no ``title`` argument.
            return self.plot_from_df()
        return self.plot_from_df(title=title)

    @abc.abstractmethod
    def plot_from_df(self, title: str = "") -> alt.Chart:
        """Create the chart from ``self._metric_data``; must be provided by subclasses."""

    def plot_spec(self) -> dict:
        """Return the chart from ``plot()`` serialised with ``to_dict(format="vega")``."""
        return self.plot().to_dict(format="vega")
[docs]
class VisualizationInfo(DBModel):
    # Descriptive metadata that the plot classes in this module expose as their
    # ``visualization_info`` class attribute.
    id: str  # identifier string of the visualization, e.g. "metric_map"
    display_name: str  # short human-readable label, e.g. "Map"
    description: str  # one-sentence explanation of what the plot shows
[docs]
class MetricByHorizonAndLocationMean(MetricPlotV2):
    """Bar chart of the metric averaged per (horizon distance, location) pair."""

    # NOTE(review): MetricByHorizonV2Mean declares the same id ("metric_by_horizon");
    # confirm whether ids are expected to be unique before relying on them as keys.
    visualization_info = VisualizationInfo(
        id="metric_by_horizon",
        display_name="Horizon Plot",
        description="Shows the aggregated metric by forecast horizon",
    )

    def plot_from_df(self, title: str = "Mean Metric by Horizon") -> alt.Chart:
        """
        Average the metric per horizon distance and location and draw it as bars.

        Args:
            title: Chart title. Accepted because ``MetricPlotV2.plot`` calls
                ``plot_from_df(title=...)``; the default is the previously fixed title.

        Returns:
            The Altair chart.
        """
        df = self._metric_data
        adf = df.groupby(["horizon_distance", "location"]).agg({"metric": "mean"}).reset_index()
        # NOTE(review): location is only shown in the tooltip, not encoded as colour or
        # facet, so the bars of all locations share one x position per horizon.
        chart = (
            alt.Chart(adf)
            .mark_bar(point=True)
            .encode(
                x=alt.X("horizon_distance:O", title="Horizon (periods ahead)"),
                y=alt.Y("metric:Q", title="Mean Metric Value"),
                tooltip=["horizon_distance", "location", "metric"],
            )
            .properties(width=300, height=230, title=title)
        )
        return chart
[docs]
class MetricByHorizonV2Mean(MetricPlotV2):
    """Bar chart of the metric averaged per forecast horizon distance."""

    visualization_info = VisualizationInfo(
        id="metric_by_horizon",
        display_name="Horizon Plot",
        description="Shows the aggregated metric by forecast horizon",
    )

    def plot_from_df(self, title="Mean metric by horizon"):
        """Return a bar chart with one bar per horizon distance showing the mean metric."""
        # One row per horizon distance, holding the mean of the metric column.
        mean_by_horizon = self._metric_data.groupby(["horizon_distance"]).agg({"metric": "mean"}).reset_index()

        horizon_axis = alt.X("horizon_distance:O", title="Horizon (periods ahead)")
        value_axis = alt.Y("metric:Q", title="Mean Metric Value")

        bars = alt.Chart(mean_by_horizon).mark_bar(point=True)
        encoded = bars.encode(
            x=horizon_axis,
            y=value_axis,
            tooltip=["horizon_distance", "metric"],
        )
        return encoded.properties(width=300, height=230, title=title)
[docs]
class MetricByHorizonV2Sum(MetricPlotV2):
    """Bar chart of the metric summed per forecast horizon distance."""

    visualization_info = VisualizationInfo(
        id="metric_by_horizon_sum",
        display_name="Horizon Plot (sum)",
        description="Sums metric across locations per forecast horizon",
    )

    def plot_from_df(self, title: str = "Samples above truth by horizon") -> alt.Chart:
        """
        Draw the summed metric per horizon distance.

        The raw metric rows are handed to the chart unchanged; the summation is done
        by the ``sum(metric)`` aggregate in the chart encoding.

        Args:
            title: Chart title. Accepted because ``MetricPlotV2.plot`` calls
                ``plot_from_df(title=...)``; the default is the previously fixed title.

        Returns:
            The Altair chart.
        """
        df = self._metric_data
        chart = (
            alt.Chart(df)
            .mark_bar()
            .encode(
                x=alt.X("horizon_distance:O", title="Horizon (periods ahead)"),
                y=alt.Y("sum(metric):Q", title="Samples above truth (count)"),
                tooltip=[
                    alt.Tooltip("horizon_distance:O", title="Horizon"),
                    alt.Tooltip("sum(metric):Q", title="Count"),
                ],
            )
            .properties(width=300, height=230, title=title)
        )
        return chart
[docs]
class MetricByTimePeriodAndLocationV2Mean(MetricPlotV2):
    """Line chart of the mean metric per time period, one line per location."""

    visualization_info = VisualizationInfo(
        id="metric_by_time_period",
        display_name="Time Period Plot",
        description="Shows the aggregated metric by time period (per location)",
    )

    def plot_from_df(self, title="Mean metric by location and time period") -> alt.Chart:
        """Return a line chart of the metric averaged per (time period, location)."""
        # One row per (time period, location) pair with the mean metric value.
        per_location = self._metric_data.groupby(["time_period", "location"]).agg({"metric": "mean"}).reset_index()

        encodings = {
            "x": alt.X("time_period:O", title="Time period"),
            "y": alt.Y("metric:Q", title="Mean Metric Value"),
            "color": alt.Color("location:N", title="Location"),
            "tooltip": ["time_period", "location", "metric"],
        }

        lines = alt.Chart(per_location).mark_line(point=True).encode(**encodings)
        return lines.properties(width=300, height=230, title=title)
[docs]
class MetricByTimePeriodV2Sum(MetricPlotV2):
    """Line chart of the metric summed per time period, one line per location."""

    # display_name/description previously repeated the horizon-sum plot's texts;
    # they now describe what this chart actually shows.
    visualization_info = VisualizationInfo(
        id="metric_by_time_sum",
        display_name="Time Period Plot (sum)",
        description="Sums metric across horizons per time period (per location)",
    )

    def plot_from_df(self, title: str = "Samples above truth by time period") -> alt.Chart:
        """
        Draw the summed metric per time period, coloured by location.

        The raw metric rows are handed to the chart unchanged; the summation is done
        by the ``sum(metric)`` aggregate in the chart encoding.

        Args:
            title: Chart title. Accepted because ``MetricPlotV2.plot`` calls
                ``plot_from_df(title=...)``; the default is the previously fixed title.

        Returns:
            The Altair chart.
        """
        df = self._metric_data
        chart = (
            alt.Chart(df)
            .mark_line()
            .encode(
                x=alt.X("time_period:O", title="Time Period"),
                y=alt.Y("sum(metric):Q", title="Samples above truth (count)"),
                color=alt.Color("location:N", title="Location"),
                tooltip=[
                    alt.Tooltip("time_period:O", title="Time Period"),
                    alt.Tooltip("sum(metric):Q", title="Count"),
                ],
            )
            .properties(width=300, height=230, title=title)
        )
        return chart
[docs]
class MetricByTimePeriodV2Mean(MetricPlotV2):
    """Line chart of the metric averaged per time period."""

    visualization_info = VisualizationInfo(
        id="metric_by_time_mean",
        display_name="Metric by time (mean)",
        description="Mean metric across locations and horizons per time period",
    )

    def plot_from_df(self, title="Mean metric by time period"):
        """
        Average the metric per time period and draw it as a single line.

        Args:
            title: Chart title.

        Returns:
            The Altair chart.
        """
        mean_by_period = self._metric_data.groupby(["time_period"]).agg({"metric": "mean"}).reset_index()
        # The frame already holds one mean per time period, so the field is plotted
        # directly rather than being wrapped in a second ``mean(...)`` aggregate.
        chart = (
            alt.Chart(mean_by_period)
            .mark_line()
            .encode(
                x=alt.X("time_period:O", title="Time Period"),
                y=alt.Y("metric:Q", title="Mean Metric Value"),
                tooltip=[
                    alt.Tooltip("time_period:O", title="Time Period"),
                    # Was labelled "Count", although the value shown is a mean.
                    alt.Tooltip("metric:Q", title="Mean Metric Value"),
                ],
            )
            .properties(width=300, height=230, title=title)
        )
        return chart
[docs]
class MetricMapV2(MetricPlotV2):
    """Choropleth map of the metric averaged per location."""

    visualization_info = VisualizationInfo(
        id="metric_map", display_name="Map", description="Shows a map of aggregated metrics per org unit"
    )

    def __init__(self, metric_data: FlatMetric, geojson: Optional[dict] = None):
        """
        Args:
            metric_data: Flat metric data to aggregate per location.
            geojson: GeoJSON dict whose ``"features"`` list provides the shapes to draw.
                Optional in the signature, but required by ``plot_from_df``.
        """
        super().__init__(metric_data, geojson)
        self._geojson = geojson

    def plot_from_df(self, title="Metric Map by location") -> alt.Chart:
        """
        Colour each GeoJSON feature by the mean metric of the matching location.

        Args:
            title: Chart title.

        Returns:
            The Altair chart.

        Raises:
            ValueError: If the plot was constructed without ``geojson``.
        """
        if self._geojson is None:
            # Fail with a clear message instead of "'NoneType' object is not subscriptable".
            raise ValueError("MetricMapV2 requires geojson to draw the map, but none was provided")

        # Aggregate metrics by location (average across all time periods and horizons)
        agg_df = self._metric_data.groupby("location").agg({"metric": "mean"}).reset_index()
        agg_df = agg_df.rename(columns={"location": "org_unit", "metric": "value"})

        # Build Altair map chart from the GeoJSON features
        chart = (
            alt.Chart(alt.Data(values=self._geojson["features"]))
            .mark_geoshape(stroke="black", strokeWidth=0.5)
            .encode(
                color=alt.Color("value:Q", scale=alt.Scale(scheme="reds"), title="Metric Value"),
                tooltip=[alt.Tooltip("properties.name:N", title="org unit"), "value:Q"],
            )
            .transform_lookup(
                # Joins on each feature's "id" field against the "org_unit" column
                # (the former "location" values) of the aggregated frame.
                lookup="id",
                from_=alt.LookupData(agg_df, "org_unit", ["value"]),
            )
            .project(type="equirectangular")  # Use equirectangular projection for proper proportions
            .properties(width=300, height=230, title=title)
        )
        return chart
[docs]
def make_plot_from_backtest_object(
    backtest: BackTest, plotting_class: "type[MetricPlotV2]", metric: MetricBase, geojson: Optional[dict] = None
) -> dict:
    """
    Compute a metric for a backtest and return the resulting plot specification.

    Args:
        backtest: Backtest whose ``forecasts`` and ``dataset.observations`` are converted
            to the flat representations.
        plotting_class: The plot class (not an instance) to instantiate with the
            computed metric data.
        metric: Metric whose ``compute(flat_observations, flat_forecasts)`` produces the
            data to plot.
        geojson: Optional GeoJSON passed through to the plot class constructor.

    Returns:
        The dict returned by the plot's ``plot_spec()``. The previous ``alt.Chart``
        annotation did not match what is actually returned.
    """
    # Convert to flat representation
    flat_forecasts = convert_backtest_to_flat_forecasts(backtest.forecasts)
    flat_observations = convert_backtest_observations_to_flat_observations(backtest.dataset.observations)
    metric_data = metric.compute(flat_observations, flat_forecasts)
    return plotting_class(metric_data, geojson).plot_spec()