|
|
""" |
|
|
plotting.py |
|
|
----------- |
|
|
Chart-generation functions for time-series visualisation. |
|
|
|
|
|
Every public function returns a :class:`matplotlib.figure.Figure` object. |
|
|
Callers (e.g. Streamlit pages) can pass the figure to ``st.pyplot(fig)`` |
|
|
or convert it to PNG bytes via :func:`fig_to_png_bytes`. |
|
|
|
|
|
All functions accept an optional *style_dict* (typically from |
|
|
:func:`ui_theme.get_miami_mpl_style`) and an optional *palette_colors* |
|
|
list so that colours stay consistent across the application. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import io |
|
|
import math |
|
|
from typing import Dict, List, Optional, Sequence |
|
|
|
|
|
|
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.dates as mdates |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default accent colour (hex).  Used whenever no palette colour is supplied,
# and always for mean lines, ACF/PACF bars, decomposition lines and lag plots.
MIAMI_RED: str = "#C41230"

# Default ``figsize`` (width, height) for the single-axes charts in this module.
_DEFAULT_FIG_SIZE = (10, 6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fig_to_png_bytes(fig: matplotlib.figure.Figure, dpi: int = 150) -> bytes:
    """Serialise *fig* as PNG and hand back the encoded bytes.

    The figure is saved with ``bbox_inches="tight"`` into an in-memory
    buffer, so no file is written.  The figure itself is left open.

    Parameters
    ----------
    fig:
        Figure to render.
    dpi:
        Output resolution passed straight to ``savefig``.
    """
    with io.BytesIO() as stream:
        fig.savefig(stream, format="png", dpi=dpi, bbox_inches="tight")
        return stream.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _StyleContext: |
|
|
"""Context manager that temporarily applies *style_dict* to rcParams. |
|
|
|
|
|
On exit the previous values are restored so that other figures are not |
|
|
affected. |
|
|
""" |
|
|
|
|
|
def __init__(self, style_dict: Optional[Dict[str, object]]): |
|
|
self._style = style_dict |
|
|
self._saved: Dict[str, object] = {} |
|
|
|
|
|
def __enter__(self) -> "_StyleContext": |
|
|
if self._style: |
|
|
for key, value in self._style.items(): |
|
|
self._saved[key] = plt.rcParams.get(key) |
|
|
try: |
|
|
plt.rcParams[key] = value |
|
|
except (KeyError, ValueError): |
|
|
pass |
|
|
return self |
|
|
|
|
|
def __exit__(self, *exc_info: object) -> None: |
|
|
for key, value in self._saved.items(): |
|
|
try: |
|
|
plt.rcParams[key] = value |
|
|
except (KeyError, ValueError): |
|
|
pass |
|
|
|
|
|
|
|
|
def _default_color(palette_colors: Optional[List[str]], idx: int = 0) -> str: |
|
|
"""Pick a colour from *palette_colors* or fall back to MIAMI_RED.""" |
|
|
if palette_colors and len(palette_colors) > idx: |
|
|
return palette_colors[idx % len(palette_colors)] |
|
|
return MIAMI_RED |
|
|
|
|
|
|
|
|
def _finish_figure(fig: matplotlib.figure.Figure) -> matplotlib.figure.Figure: |
|
|
"""Apply common finishing touches and return the figure.""" |
|
|
fig.tight_layout() |
|
|
return fig |
|
|
|
|
|
|
|
|
def _auto_date_axis(ax: plt.Axes) -> None:
    """Give *ax* automatic date ticks/labels on the x-axis and tilt the labels.

    A single ``AutoDateLocator`` is installed as the major locator, and the
    same instance is handed to ``AutoDateFormatter``.  The formatter picks
    its format string from its locator's current tick frequency, so it must
    be the locator the axis actually uses.  A separate, never-used locator
    keeps its initial (yearly) frequency, and every tick would be labelled
    with the year only.

    Tick labels are rotated 30 degrees and right-aligned so that long date
    strings do not overlap.
    """
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(locator))
    for label in ax.get_xticklabels():
        label.set_rotation(30)
        label.set_ha("right")
|
|
|
|
|
|
|
|
def _grid_dims(n: int) -> tuple[int, int]: |
|
|
"""Return (nrows, ncols) for a compact grid of *n* panels.""" |
|
|
ncols = min(n, 3) |
|
|
nrows = math.ceil(n / ncols) |
|
|
return nrows, ncols |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_line_with_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Draw *y_col* against *date_col* as a line with small round markers.

    The line takes colour index 0 from ``_default_color`` (first palette
    entry, or ``MIAMI_RED`` without a palette).  Axes are labelled with the
    column names, the title is set only when given, x ticks are date
    formatted, and a legend shows a single entry named *y_col*.
    """
    line_style = dict(marker="o", markersize=4, linewidth=1.5)

    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        ax.plot(
            df[date_col],
            df[y_col],
            color=_default_color(palette_colors, 0),
            label=y_col,
            **line_style,
        )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)

        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_line_colored_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    color_by: str,
    palette_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Line plot where marker colour varies by a categorical column.

    A thin grey line connects all points in row order.  On top of it, one
    scatter layer is drawn per unique value of *color_by*, and the legend
    (titled *color_by*) maps each category to its colour.

    Parameters
    ----------
    color_by:
        Column whose distinct values define the marker groups.  Missing
        values (NaN/None/NaT) form their own group, labelled e.g. ``"nan"``.
    palette_colors:
        Colours assigned to categories in order of first appearance.  They
        are cycled when there are more categories than colours.  If empty,
        ``MIAMI_RED`` is used for every category.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        # Neutral connecting line drawn underneath the coloured markers.
        ax.plot(
            df[date_col], df[y_col],
            linewidth=1.0, color="#AAAAAA", zorder=1,
        )

        groups = df[color_by]
        categories = groups.unique()
        n_cats = len(categories)

        # Cycle the palette by index.  An empty palette would otherwise make
        # every lookup raise IndexError.
        colors = list(palette_colors) if palette_colors else [MIAMI_RED]

        for i, cat in enumerate(categories):
            # NaN != NaN, so a missing-value category must be selected with
            # isna().  Colours are indexed by position rather than looked up
            # by category, because NaN scalars cannot be found again as keys.
            if pd.api.types.is_scalar(cat) and pd.isna(cat):
                mask = groups.isna()
            else:
                mask = groups == cat
            ax.scatter(
                df.loc[mask, date_col], df.loc[mask, y_col],
                c=colors[i % len(colors)], label=str(cat),
                s=30, zorder=2, edgecolors="white", linewidths=0.3,
            )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        # Spread long legends over several columns (one extra per 8 entries).
        ax.legend(title=color_by, loc="best", fontsize=8, ncol=max(1, n_cats // 8))
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_seasonal(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    palette_name_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Seasonal plot: one line per calendar year, x-axis is the position within the year.

    Parameters
    ----------
    period:
        Values starting with ``"q"`` (case-insensitive, e.g. ``"quarter"``)
        put quarter 1-4 on the x-axis.  Anything else (normally ``"month"``)
        uses month 1-12.
    palette_name_colors:
        List of hex colours, one per year in ascending order.  They are
        cycled when there are more years than colours.  If empty,
        ``MIAMI_RED`` is used for every year.

    Notes
    -----
    Rows whose date is missing are dropped before grouping, which keeps the
    year labels integral (e.g. ``"2020"`` rather than ``"2020.0"``).  Every
    position (1-4 or 1-12) gets its own x tick.  No aggregation is done:
    data finer than *period* yields several points per position in a year.
    """
    with _StyleContext(style_dict):
        tmp = df[[date_col, y_col]].dropna(subset=[date_col]).copy()
        tmp["_year"] = tmp[date_col].dt.year

        if period.lower().startswith("q"):
            tmp["_period_pos"] = tmp[date_col].dt.quarter
            x_label = "Quarter"
            n_positions = 4
        else:
            tmp["_period_pos"] = tmp[date_col].dt.month
            x_label = "Month"
            n_positions = 12

        years = sorted(tmp["_year"].unique())
        n_years = len(years)
        # Cycle the palette by index; an empty palette falls back to the default.
        colors = list(palette_name_colors) if palette_name_colors else [MIAMI_RED]

        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        for i, year in enumerate(years):
            sub = tmp[tmp["_year"] == year].sort_values("_period_pos")
            ax.plot(
                sub["_period_pos"], sub[y_col],
                marker="o", markersize=4, linewidth=1.4,
                color=colors[i % len(colors)], label=str(year),
            )

        # Integer ticks only: without this, quarter axes get 1.5, 2.5, ... ticks.
        ax.set_xticks(range(1, n_positions + 1))
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        ax.legend(title="Year", loc="best", fontsize=8, ncol=max(1, n_years // 6))
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_seasonal_subseries(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Seasonal subseries plot: one panel per season, each with its mean.

    Observations are grouped by month, or by quarter when *period* starts
    with ``"q"`` (case-insensitive).  Each season present in the data gets
    its own panel in a single row with a shared y-axis.  Inside a panel the
    values are drawn in date order against positions 0, 1, 2, ..., with no
    x ticks.  A dashed ``MIAMI_RED`` horizontal line marks the season's mean.

    Parameters
    ----------
    period:
        ``"month"`` or ``"quarter"``.
    palette_colors:
        The first colour (via ``_default_color``) is used for the data lines.
    """
    month_abbrev = (
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    )

    with _StyleContext(style_dict):
        frame = df[[date_col, y_col]].copy()

        if period.lower().startswith("q"):
            frame["_season"] = frame[date_col].dt.quarter
            season_names = {q: f"Q{q}" for q in range(1, 5)}
        else:
            frame["_season"] = frame[date_col].dt.month
            season_names = dict(enumerate(month_abbrev, start=1))

        seasons = sorted(frame["_season"].unique())
        n_panels = len(seasons)
        # squeeze=False always yields a 2-D grid, so row 0 is a sequence of
        # axes even when only one season exists.
        fig, panel_grid = plt.subplots(
            1, n_panels,
            figsize=(max(10, n_panels * 1.3), 5),
            sharey=True, squeeze=False,
        )
        panels = panel_grid[0]
        line_color = _default_color(palette_colors, 0)

        for position, (ax, season) in enumerate(zip(panels, seasons)):
            season_rows = frame.loc[frame["_season"] == season].sort_values(date_col)
            season_values = season_rows[y_col]

            ax.plot(
                range(len(season_rows)), season_values.values,
                marker="o", markersize=3, linewidth=1.2, color=line_color,
            )
            ax.axhline(
                season_values.mean(),
                color=MIAMI_RED, linewidth=1.8, linestyle="--", alpha=0.8,
            )

            ax.set_title(season_names.get(season, str(season)), fontsize=10)
            ax.set_xticks([])
            ax.tick_params(axis="y", labelsize=8)
            if position == 0:
                ax.set_ylabel(y_col)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_acf_pacf(
    acf_vals: np.ndarray,
    acf_ci: np.ndarray,
    pacf_vals: np.ndarray,
    pacf_ci: np.ndarray,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Side-by-side ACF (left) and PACF (right) bar plots with shaded CI bands.

    Parameters
    ----------
    acf_vals, pacf_vals:
        1-D arrays of correlation values, one bar per lag starting at 0.
    acf_ci, pacf_ci:
        Arrays of shape ``(n_lags, 2)``; column 0 is the lower bound and
        column 1 the upper bound.  The band is shaded exactly between these
        bounds.  NOTE(review): if callers pass statsmodels' ``confint``
        unmodified, the band is centred on the correlation values rather
        than on zero; confirm what callers supply.
    """
    panels = (
        ("ACF", acf_vals, acf_ci),
        ("PACF", pacf_vals, pacf_ci),
    )

    with _StyleContext(style_dict):
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        for ax, (panel_name, values, bounds) in zip(axes, panels):
            lag_idx = np.arange(len(values))
            ax.bar(lag_idx, values, width=0.3, color=MIAMI_RED, alpha=0.85, zorder=2)
            # Confidence band behind the bars (lower zorder).
            ax.fill_between(
                lag_idx, bounds[:, 0], bounds[:, 1],
                color=MIAMI_RED, alpha=0.12, zorder=1,
            )
            ax.axhline(0, color="black", linewidth=0.8)

            ax.set_xlabel("Lag")
            ax.set_ylabel("Correlation")
            ax.set_title(panel_name)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_decomposition(
    decomposition_result,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """4-panel plot: observed, trend, seasonal, residual.

    Each component is drawn in ``MIAMI_RED`` on its own row; the rows share
    the x-axis.

    Parameters
    ----------
    decomposition_result:
        An object with ``.observed``, ``.trend``, ``.seasonal`` and
        ``.resid`` attributes (e.g. the result of
        ``statsmodels.tsa.seasonal.seasonal_decompose``).  Components that
        carry a pandas index are plotted against it.  Plain array-likes
        (e.g. ndarrays) are plotted against positions ``0 .. len-1``.

    Notes
    -----
    Date formatting is applied to the shared x-axis only when the observed
    component has a pandas index that is not numeric.  A ``RangeIndex`` or
    array input keeps plain numeric ticks instead of being rendered as dates.
    """
    with _StyleContext(style_dict):
        components = [
            ("Observed", decomposition_result.observed),
            ("Trend", decomposition_result.trend),
            ("Seasonal", decomposition_result.seasonal),
            ("Residual", decomposition_result.resid),
        ]
        fig, axes = plt.subplots(4, 1, figsize=(10, 10), sharex=True)

        for ax, (label, series) in zip(axes, components):
            index = getattr(series, "index", None)
            if isinstance(index, pd.Index):
                x, y = index, series.values
            else:
                # No pandas index available: plot against integer positions.
                y = np.asarray(series)
                x = np.arange(len(y))
            ax.plot(x, y, linewidth=1.2, color=MIAMI_RED)
            ax.set_ylabel(label, fontsize=10)
            ax.tick_params(axis="both", labelsize=9)

        # Only the bottom axis shows x tick labels (sharex); format it as
        # dates unless the x values are plain numbers.
        observed_index = getattr(decomposition_result.observed, "index", None)
        if isinstance(observed_index, pd.Index) and not pd.api.types.is_numeric_dtype(observed_index):
            _auto_date_axis(axes[-1])

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.01)
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_rolling_overlay(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    window: int,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Overlay a centred rolling mean, plus a +/-1 rolling-std band, on the raw series.

    The raw values are drawn thin and translucent in palette colour 0
    (``MIAMI_RED`` without a palette).  The rolling mean and its shaded band
    use the second palette colour when at least two colours are supplied,
    and dark grey (``#333333``) otherwise.  Both statistics use a centred
    window of *window* points with pandas' default ``min_periods``, so they
    are NaN near the edges where the window is incomplete.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        has_second_colour = bool(palette_colors) and len(palette_colors) > 1
        raw_color = _default_color(palette_colors, 0)
        mean_color = _default_color(palette_colors, 1) if has_second_colour else "#333333"

        x = df[date_col]
        raw = df[y_col]
        roller = raw.rolling(window=window, center=True)
        centre = roller.mean()
        spread = roller.std()

        # Background: the unsmoothed series.
        ax.plot(x, raw, linewidth=0.8, alpha=0.4, color=raw_color, label="Original")
        # Foreground: the smoothed trend.
        ax.plot(x, centre, linewidth=2.2, color=mean_color,
                label=f"{window}-pt Rolling Mean")
        lower, upper = centre - spread, centre + spread
        ax.fill_between(x, lower, upper,
                        alpha=0.15, color=mean_color, label="\u00b11 Std Dev")

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_yoy_change(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    yoy_df: pd.DataFrame,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Two-subplot bar chart: absolute YoY change (top) and percentage YoY change (bottom).

    Bars are green where the change is >= 0 and red otherwise (including
    missing values).  Both subplots share the x-axis.

    Parameters
    ----------
    df, date_col, y_col:
        Accepted for interface compatibility; the chart is drawn entirely
        from *yoy_df*.
    yoy_df:
        DataFrame with columns ``"date"``, ``"abs_change"``, ``"pct_change"``.

    Notes
    -----
    When ``yoy_df["date"]`` is datetime-typed, bar width is 0.65 times the
    median spacing between sorted dates.  That is about 20 days for monthly
    data, and it scales so that weekly bars do not overlap and annual bars
    are not slivers.  Otherwise (non-datetime dates, fewer than two dates,
    or a zero median gap) the width falls back to 20.
    """
    with _StyleContext(style_dict):
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

        dates = yoy_df["date"]
        abs_change = yoy_df["abs_change"]
        pct_change = yoy_df["pct_change"]

        bar_width = 20
        if pd.api.types.is_datetime64_any_dtype(dates):
            gaps = dates.dropna().sort_values().diff().dropna()
            if len(gaps):
                median_gap = gaps.median()
                # A zero gap (all dates identical) would make bars invisible.
                if pd.notna(median_gap) and median_gap > pd.Timedelta(0):
                    bar_width = median_gap * 0.65

        # NaN >= 0 is False, so missing values are coloured red.
        abs_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in abs_change]
        pct_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in pct_change]

        ax1.bar(dates, abs_change, color=abs_colors, width=bar_width, edgecolor="white", linewidth=0.3)
        ax1.axhline(0, color="black", linewidth=0.6)
        ax1.set_ylabel("Absolute Change")
        ax1.set_title("Year-over-Year Absolute Change")

        ax2.bar(dates, pct_change, color=pct_colors, width=bar_width, edgecolor="white", linewidth=0.3)
        ax2.axhline(0, color="black", linewidth=0.6)
        ax2.set_ylabel("% Change")
        ax2.set_title("Year-over-Year Percentage Change")

        # The x-axis is shared, so formatting the bottom axis covers both.
        _auto_date_axis(ax2)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_lag(
    series: pd.Series,
    lag: int = 1,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Scatter plot of y(t) vs y(t-lag) with a Pearson correlation annotation.

    Pairs are formed on the original positions of *series* first.  Only
    afterwards are pairs containing a missing value discarded.  Dropping
    NaNs before shifting would pair observations that are not actually
    *lag* steps apart.  With fewer than two complete pairs the annotation
    shows ``r = nan``.

    Parameters
    ----------
    series:
        Values in time order; missing values are allowed.
    lag:
        Positive shift, in observations.

    Raises
    ------
    ValueError
        If *lag* is smaller than 1.
    """
    if lag < 1:
        raise ValueError(f"lag must be a positive integer, got {lag!r}")

    values = series.to_numpy(dtype=float, na_value=np.nan)
    y_t = values[lag:]
    y_lag = values[:-lag]
    # Keep only pairs where both observations are present.
    complete = ~(np.isnan(y_t) | np.isnan(y_lag))
    y_t, y_lag = y_t[complete], y_lag[complete]
    # corrcoef is undefined for fewer than two points.
    corr = np.corrcoef(y_t, y_lag)[0, 1] if len(y_t) >= 2 else float("nan")

    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=(7, 7))
        ax.scatter(y_lag, y_t, alpha=0.5, s=20, color=MIAMI_RED, edgecolors="white", linewidths=0.3)

        # Correlation badge in the top-left corner (axes coordinates).
        ax.annotate(
            f"r = {corr:.3f}",
            xy=(0.05, 0.95), xycoords="axes fraction",
            fontsize=12, fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="#CCCCCC", alpha=0.9),
            verticalalignment="top",
        )

        ax.set_xlabel(f"y(t\u2212{lag})")
        ax.set_ylabel("y(t)")
        if title:
            ax.set_title(title)
        else:
            ax.set_title(f"Lag-{lag} Plot")
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_panel(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    chart_type: str = "line",
    shared_y: bool = True,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Small multiples: one subplot per *y_col* arranged in a grid.

    The grid shape comes from ``_grid_dims``, and unused cells are hidden.
    Panel *i* is coloured with ``_default_color(palette_colors, i)`` and
    titled with its column name.

    Parameters
    ----------
    chart_type:
        ``"bar"`` draws bars; any other value (normally ``"line"``) draws lines.
    shared_y:
        If ``True`` all panels share the same y-axis limits.

    Notes
    -----
    Bar width is 90% of the median spacing between consecutive sorted
    dates.  It falls back to ``1`` when there are fewer than two dates or
    when the median spacing is zero or undefined, so bars never become
    invisible.  An empty *y_cols* produces a blank figure instead of an
    error.
    """
    with _StyleContext(style_dict):
        n = len(y_cols)
        # Lay out at least one cell; with no columns it is simply hidden below.
        nrows, ncols = _grid_dims(max(n, 1))
        fig_h = max(4, nrows * 3.5)
        fig_w = max(8, ncols * 4.5)
        fig, axes = plt.subplots(
            nrows, ncols, figsize=(fig_w, fig_h),
            sharey=shared_y, squeeze=False,
        )
        flat_axes = axes.flatten()

        bar_width = None
        if chart_type == "bar":
            bar_width = 1
            dates_sorted = df[date_col].dropna().sort_values()
            if len(dates_sorted) >= 2:
                median_gap = dates_sorted.diff().dropna().median()
                # Duplicate-only dates give a zero gap; keep the fallback width.
                if pd.notna(median_gap) and bool(median_gap):
                    bar_width = median_gap * 0.9

        for i, col in enumerate(y_cols):
            ax = flat_axes[i]
            color = _default_color(palette_colors, i)
            if chart_type == "bar":
                ax.bar(df[date_col], df[col], width=bar_width, color=color, alpha=0.85, edgecolor="white", linewidth=0.3)
            else:
                ax.plot(df[date_col], df[col], linewidth=1.3, color=color)
            ax.set_title(col, fontsize=10)
            _auto_date_axis(ax)

        # Hide grid cells that received no series.
        for j in range(n, len(flat_axes)):
            flat_axes[j].set_visible(False)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_spaghetti(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    alpha: float = 0.15,
    highlight_col: Optional[str] = None,
    top_n: Optional[int] = None,
    show_median_band: bool = False,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """All series on one plot at low opacity, with optional highlighting.

    Series *i* is coloured with ``_default_color(palette_colors, i)``.
    Highlighted series are drawn thicker, nearly opaque, above the others,
    and appear in the legend.  Background series get no legend entry.

    Parameters
    ----------
    alpha:
        Opacity for the background spaghetti lines.
    highlight_col:
        Column name to draw with full opacity and thicker line (ignored if
        it is not in *y_cols*).
    top_n:
        If set, highlight the *top_n* series by maximum value.  Columns
        whose maximum is missing (all-NaN) are never selected.  Ties keep
        the order of *y_cols*.
    show_median_band:
        If ``True``, overlay the row-wise median line and shade the IQR
        (25th-75th percentile) across *y_cols*.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        dates = df[date_col]

        # Decide which series get the emphasised style.
        highlight_set: set[str] = set()
        if highlight_col and highlight_col in y_cols:
            highlight_set.add(highlight_col)
        if top_n:
            peak_by_col = {col: df[col].max() for col in y_cols}
            # Rank only columns with a defined maximum.  NaN compares False
            # with everything, so letting it into sorted() could scramble the
            # order and pick the wrong columns.
            ranked = sorted(
                (col for col in peak_by_col if pd.notna(peak_by_col[col])),
                key=peak_by_col.get,
                reverse=True,
            )
            highlight_set.update(ranked[:top_n])

        for i, col in enumerate(y_cols):
            color = _default_color(palette_colors, i)
            if col in highlight_set:
                ax.plot(dates, df[col], linewidth=2.0, alpha=0.9,
                        color=color, label=col, zorder=3)
            else:
                ax.plot(dates, df[col], linewidth=0.8, alpha=alpha,
                        color=color, zorder=1)

        if show_median_band:
            numeric_data = df[y_cols]
            median_line = numeric_data.median(axis=1)
            q1 = numeric_data.quantile(0.25, axis=1)
            q3 = numeric_data.quantile(0.75, axis=1)
            ax.plot(dates, median_line, linewidth=2.2, color="#333333",
                    label="Median", zorder=4)
            ax.fill_between(dates, q1, q3, alpha=0.2, color="#333333",
                            label="IQR", zorder=2)

        ax.set_xlabel(date_col)
        ax.set_ylabel("Value")
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)

        # Only draw a legend when something was labelled (avoids a warning).
        _, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend(loc="best", fontsize=8)
        return _finish_figure(fig)
|
|
|