Source code for chap_core.plotting.season_plot

import numpy as np
import pandas as pd
import altair as alt
from .dataset_plot import DatasetPlot

# Module-level side effect: importing this module sets Altair's default renderer to "browser".
# NOTE(review): presumably so charts open in a web browser when displayed — confirm this is
# wanted for every importer of this module rather than only for interactive use.
alt.renderers.default = "browser"


class SeasonPlot(DatasetPlot):
    """Faceted line chart of log1p disease cases per month of the season.

    Seasons are aligned so that seasonal month 0 is the calendar month with the
    lowest mean ``log1p(disease_cases)`` over the whole dataset. Each season is
    drawn as its own line, and each location gets its own facet.

    NOTE(review): ``self._df`` is presumably set by ``DatasetPlot.__init__`` and is
    expected to hold ``disease_cases``, ``time_period`` and ``location`` columns —
    confirm against the base class.
    """

    def data(self) -> pd.DataFrame:
        """Return a copy of the dataset with seasonal helper columns added.

        Added columns (in this order): ``log1p``, ``month`` (0-based calendar
        month), ``year``, ``seasonal_month`` (0..11, counted from the season
        start) and ``season_idx`` (0-based season counter). ``time_period`` is
        converted to datetimes in place.

        Raises:
            AssertionError: if the largest 0-based month in the data is not 11.
        """
        frame = self._df.copy()
        frame["log1p"] = np.log1p(frame["disease_cases"])
        frame["time_period"] = pd.to_datetime(frame["time_period"])

        period_parts = frame["time_period"].dt
        frame["month"] = period_parts.month - 1
        frame["year"] = period_parts.year

        # The season starts at the calendar month whose mean log1p is lowest.
        monthly_means = (
            (month, rows["log1p"].mean()) for month, rows in frame.groupby("month")
        )
        min_month, _lowest_mean = min(monthly_means, key=lambda pair: pair[1])
        assert frame["month"].max() == 11

        # Months before the season start get a negative shift: the modulo wraps
        # them to the end of the season and the floor division assigns them to
        # the season that began in the previous calendar year.
        shifted_month = frame["month"] - min_month
        frame["seasonal_month"] = shifted_month % 12

        # Season index counted in years from the earliest season present.
        absolute_season = frame["year"] + shifted_month // 12
        frame["season_idx"] = absolute_season - absolute_season.min()
        return frame

    def plot(self) -> alt.FacetChart:
        """Build the per-location line chart, three facets per row."""
        seasonal = self.data()

        month_axis = alt.X("seasonal_month:O", title="Month")
        cases_axis = alt.Y("log1p:Q", title="Log1p Disease Cases")
        season_color = alt.Color("season_idx:N", title="Season Year")

        lines = (
            alt.Chart(seasonal)
            .mark_line(point=False, strokeWidth=2)
            .encode(x=month_axis, y=cases_axis, color=season_color)
        )
        return lines.facet(facet=alt.Facet("location:N", title="Location"), columns=3)
class SeasonCorrelationPlot(DatasetPlot):
    """Scatter plot of mean temperature against the per-season maximum of log1p cases.

    The seasonal columns come from an internal ``SeasonPlot`` built with the
    same constructor arguments as this plot.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base plot and a companion ``SeasonPlot`` from the same arguments."""
        super().__init__(*args, **kwargs)
        self._season_plot = SeasonPlot(*args, **kwargs)

    def data(self) -> pd.DataFrame:
        """Return the seasonal data with per-season summary statistics attached.

        For every ``(location, season_idx)`` pair, the mean, standard deviation
        and maximum of ``log1p`` are computed and left-joined back onto each row
        as ``season_mean``, ``season_std`` and ``season_max``.
        """
        join_keys = ["location", "season_idx"]
        seasonal = self._season_plot.data()

        per_season = seasonal.groupby(join_keys)["log1p"]
        summary = per_season.agg(["mean", "std", "max"]).reset_index()
        summary.columns = ["location", "season_idx", "season_mean", "season_std", "season_max"]

        return seasonal.merge(summary, on=join_keys, how="left")

    def plot(self):
        """Build a location-by-seasonal-month grid of scatter plots with independent axes."""
        enriched = self.data()

        temperature_axis = alt.X(
            "mean_temperature:Q",
            title="Predictor: Mean Temperature",
            scale=alt.Scale(zero=False),
        )
        season_max_axis = alt.Y(
            "season_max:Q",
            title="Season Max (Log1p Disease Cases)",
            scale=alt.Scale(zero=False),
        )
        month_color = alt.Color("seasonal_month:N", title="Seasonal Month")
        hover_fields = ["location:N", "season_idx:N", "season_mean:Q", "season_std:Q", "mean_temperature:Q"]

        points = (
            alt.Chart(enriched)
            .mark_point(filled=True, size=100)
            .encode(x=temperature_axis, y=season_max_axis, color=month_color, tooltip=hover_fields)
        )
        grid = points.facet(
            row=alt.Row("location:N", title="Location"),
            column=alt.Column("seasonal_month:O", title="Seasonal Month"),
        )
        return grid.resolve_scale(y="independent", x="independent")
class SeasonCorrelationBarPlot(SeasonCorrelationPlot):
    """Bar chart of correlations between features and per-season outcome statistics.

    For every location and seasonal month, each feature is correlated with the
    season's ``max``, ``mean`` and ``std`` of log1p cases. The last three months
    of a season are also included as months -3..-1 of the *following* season,
    so late-season feature values can be compared with the next season's outcome.
    """

    feature_name = "mean_temperature"  # Example feature to correlate with season_mean

    def data(self) -> pd.DataFrame:
        """Return one correlation row per location, seasonal month, feature and outcome.

        Returns:
            DataFrame with columns ``location``, ``seasonal_month``,
            ``correlation``, ``feature``, ``outcome`` and ``combination``.

        NOTE(review): the feature list comes from ``self._get_feature_names()``,
        which is not defined in this module — presumably inherited from
        ``DatasetPlot``; confirm.
        """
        df = super().data()

        join_keys = ["location", "season_idx"]
        stat_columns = ["season_mean", "season_std", "season_max"]

        # One row of outcome statistics per (location, season).
        season_stats = df[join_keys + stat_columns].drop_duplicates(subset=join_keys)

        # Re-express the last three months of each season as months -3..-1 of the
        # next season. The parent attached the statistics of the *original* season
        # to these rows, so they are dropped and replaced by the statistics of the
        # season the rows now belong to. Without this step the lagged months would
        # just reproduce the correlations of months 9..11. The inner join discards
        # rows whose following season is not in the data.
        lagged = df[df["seasonal_month"] >= 9].copy().drop(columns=stat_columns)
        lagged["seasonal_month"] -= 12
        lagged["season_idx"] += 1
        lagged = lagged.merge(season_stats, on=join_keys, how="inner")

        df = pd.concat([df, lagged], ignore_index=True)

        # Correlate every feature with every season outcome statistic, separately
        # for each (location, seasonal_month) group; each group holds one row per season.
        correlations = []
        for (location, seasonal_month), group in df.groupby(["location", "seasonal_month"]):
            for feature_name in self._get_feature_names():
                for outcome in ["max", "mean", "std"]:
                    # Suppress RuntimeWarning for invalid values (e.g., when std is 0)
                    with np.errstate(invalid="ignore"):
                        corr = group[f"season_{outcome}"].corr(group[feature_name])
                    correlations.append(
                        {
                            "location": location,
                            "seasonal_month": seasonal_month,
                            "correlation": corr,
                            "feature": feature_name,
                            "outcome": outcome,
                            "combination": f"{feature_name}_vs_season_{outcome}",
                        }
                    )
        return pd.DataFrame(correlations)

    def plot(self) -> alt.FacetChart:
        """Build the bar chart grid: one column per location, one row per feature/outcome pair."""
        df = self.data()
        return (
            alt.Chart(df)
            .mark_bar()
            .encode(
                x=alt.X("seasonal_month:O", title="Seasonal Month"),
                y=alt.Y("correlation:Q", title="Correlation (Season Max vs Temperature)"),
                color=alt.Color(
                    "correlation:Q", scale=alt.Scale(scheme="redblue", domain=[-1, 1]), title="Correlation"
                ),
                tooltip=["location:N", "seasonal_month:O", "correlation:Q"],
            )
            .facet(
                column=alt.Column("location:N", title="Location"),
                row=alt.Row("combination:N", title="Feature vs Outcome"),
            )
            .properties(
                title={
                    "text": "Seasonal Correlation Analysis",
                    "subtitle": "Correlation between season maximum disease cases and mean temperature by seasonal month and location. Red bars indicate negative correlation, blue bars indicate positive correlation.",
                }
            )
        )
def test_season_plot(df: pd.DataFrame):
    """Smoke-test ``SeasonPlot``: check the seasonal column and write the chart to disk.

    NOTE(review): ``df`` is presumably supplied by a pytest fixture — confirm.
    """
    season_plot = SeasonPlot(df)
    seasonal = season_plot.data()
    print(seasonal)
    assert "seasonal_month" in seasonal.columns

    rendered = season_plot.plot()
    for target in ("season_plot.html", "season_plot.png"):
        rendered.save(target)
def test_season_correlation_plot(df: pd.DataFrame):
    """Smoke-test ``SeasonCorrelationPlot``: check the season statistics and write the chart to disk.

    NOTE(review): ``df`` is presumably supplied by a pytest fixture — confirm.
    """
    correlation_plot = SeasonCorrelationPlot(df)
    enriched = correlation_plot.data()
    print(enriched)
    for expected_column in ("season_mean", "season_std"):
        assert expected_column in enriched.columns

    rendered = correlation_plot.plot()
    for target in ("season_correlation_plot.html", "season_correlation_plot.png"):
        rendered.save(target)
def test_season_correlation_bar_plot(df: pd.DataFrame):
    """Smoke-test ``SeasonCorrelationBarPlot``: check the result columns and write the chart to disk.

    NOTE(review): ``df`` is presumably supplied by a pytest fixture — confirm.
    """
    bar_plot = SeasonCorrelationBarPlot(df)
    correlations = bar_plot.data()
    print(correlations)
    for expected_column in ("correlation", "location"):
        assert expected_column in correlations.columns

    rendered = bar_plot.plot()
    for target in ("season_correlation_bar_plot.html", "season_correlation_bar_plot.png"):
        rendered.save(target)