fmegahed's picture
Cleaning up the app
789e257
"""
plotting.py
-----------
Chart-generation functions for time-series visualisation.
Every public function returns a :class:`matplotlib.figure.Figure` object.
Callers (e.g. Streamlit pages) can pass the figure to ``st.pyplot(fig)``
or convert it to PNG bytes via :func:`fig_to_png_bytes`.
All functions accept an optional *style_dict* (typically from
:func:`ui_theme.get_miami_mpl_style`) and an optional *palette_colors*
list so that colours stay consistent across the application.
"""
from __future__ import annotations
import io
import math
from typing import Dict, List, Optional, Sequence
# CRITICAL: set the non-interactive backend before any other mpl import.
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import matplotlib.dates as mdates # noqa: E402
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
# ---------------------------------------------------------------------------
# Brand defaults (mirrors ui_theme.py)
# ---------------------------------------------------------------------------
MIAMI_RED: str = "#C41230"
_DEFAULT_FIG_SIZE = (10, 6)
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def fig_to_png_bytes(fig: matplotlib.figure.Figure, dpi: int = 150) -> bytes:
    """Serialise *fig* as a PNG held in memory and return its raw bytes.

    The figure is saved with ``bbox_inches="tight"`` so surrounding
    whitespace is trimmed; *dpi* controls the output resolution.
    """
    with io.BytesIO() as stream:
        fig.savefig(stream, format="png", dpi=dpi, bbox_inches="tight")
        # getvalue() returns the whole buffer regardless of the write position.
        return stream.getvalue()
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
class _StyleContext:
"""Context manager that temporarily applies *style_dict* to rcParams.
On exit the previous values are restored so that other figures are not
affected.
"""
def __init__(self, style_dict: Optional[Dict[str, object]]):
self._style = style_dict
self._saved: Dict[str, object] = {}
def __enter__(self) -> "_StyleContext":
if self._style:
for key, value in self._style.items():
self._saved[key] = plt.rcParams.get(key)
try:
plt.rcParams[key] = value
except (KeyError, ValueError):
pass
return self
def __exit__(self, *exc_info: object) -> None:
for key, value in self._saved.items():
try:
plt.rcParams[key] = value
except (KeyError, ValueError):
pass
def _default_color(palette_colors: Optional[List[str]], idx: int = 0) -> str:
"""Pick a colour from *palette_colors* or fall back to MIAMI_RED."""
if palette_colors and len(palette_colors) > idx:
return palette_colors[idx % len(palette_colors)]
return MIAMI_RED
def _finish_figure(fig: matplotlib.figure.Figure) -> matplotlib.figure.Figure:
"""Apply common finishing touches and return the figure."""
fig.tight_layout()
return fig
def _auto_date_axis(ax: plt.Axes) -> None:
    """Give *ax* an automatic date locator/formatter pair and tilt x labels.

    ``AutoDateFormatter`` chooses its format string from the tick interval
    picked by the locator it was constructed with, so that locator must be
    the one the axis actually uses.  A detached ``AutoDateLocator`` is never
    asked for ticks, so its frequency stays at its initial (yearly) value
    and every label collapses to just the year.
    """
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(locator))
    # Rotate and right-align so long date labels do not overlap.
    for label in ax.get_xticklabels():
        label.set_rotation(30)
        label.set_ha("right")
def _grid_dims(n: int) -> tuple[int, int]:
"""Return (nrows, ncols) for a compact grid of *n* panels."""
ncols = min(n, 3)
nrows = math.ceil(n / ncols)
return nrows, ncols
# ===================================================================
# 1. Line with markers
# ===================================================================
def plot_line_with_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Draw *y_col* against *date_col* as a line with small circle markers.

    The trace colour is the first entry of *palette_colors*, falling back to
    ``MIAMI_RED`` when no palette is given.  Axis labels are the column
    names, and *title* is applied only when it is truthy.
    """
    with _StyleContext(style_dict):
        line_color = _default_color(palette_colors, 0)
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        x_values, y_values = df[date_col], df[y_col]
        ax.plot(
            x_values,
            y_values,
            color=line_color,
            label=y_col,
            marker="o",
            markersize=4,
            linewidth=1.5,
        )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)

        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
# ===================================================================
# 2. Line with coloured markers
# ===================================================================
def plot_line_colored_markers(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    color_by: str,
    palette_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Line plot whose marker colour varies by a categorical column.

    A neutral grey line joins all points in row order.  Markers are then
    drawn one category at a time, so each distinct value of *color_by* gets
    its own colour and legend entry.

    Parameters
    ----------
    df:
        Source data.
    date_col, y_col:
        Columns used for the x and y coordinates.
    color_by:
        Column whose distinct values select the marker colour.  Missing
        values (NaN/None) form their own category and are drawn too.
    palette_colors:
        Colours assigned to categories in order of first appearance.  The
        palette is cycled when there are more categories than colours; an
        empty (or ``None``) palette falls back to ``MIAMI_RED``.
    title:
        Optional axes title.
    style_dict:
        Optional rcParams overrides applied while the figure is built.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        # Draw the connecting line in a neutral grey, underneath the markers.
        ax.plot(
            df[date_col], df[y_col],
            linewidth=1.0, color="#AAAAAA", zorder=1,
        )

        groups = df[color_by]
        categories = groups.unique()  # order of first appearance
        n_cats = len(categories)
        palette = list(palette_colors or []) or [MIAMI_RED]

        for i, cat in enumerate(categories):
            # NaN == NaN is False, so missing values must be selected via isna().
            mask = groups.isna() if pd.isna(cat) else groups == cat
            ax.scatter(
                df.loc[mask, date_col], df.loc[mask, y_col],
                c=palette[i % len(palette)], label=str(cat),
                s=30, zorder=2, edgecolors="white", linewidths=0.3,
            )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        # Skip the legend for an empty frame (it would have no handles).
        if n_cats:
            ax.legend(title=color_by, loc="best", fontsize=8, ncol=max(1, n_cats // 8))
        return _finish_figure(fig)
# ===================================================================
# 3. Seasonal plot
# ===================================================================
def plot_seasonal(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    palette_name_colors: List[str],
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Seasonal plot: one line per calendar year, x is the position in the year.

    Parameters
    ----------
    df:
        Source data; *date_col* must be datetime-typed (``.dt`` is used).
    date_col, y_col:
        Date column and value column.
    period:
        ``"month"`` (x = month 1-12) or ``"quarter"`` (x = quarter 1-4); any
        string starting with ``"q"`` (case-insensitive) selects quarters.
    palette_name_colors:
        List of colours, one per year in ascending order.  It is cycled when
        there are more years than colours; an empty palette falls back to
        ``MIAMI_RED``.
    title:
        Optional axes title.
    style_dict:
        Optional rcParams overrides applied while the figure is built.

    Notes
    -----
    Rows with a missing date are dropped first.  Otherwise ``.dt.year``
    becomes float (legend labels such as ``"2020.0"``) and a spurious
    ``"nan"`` year appears.
    """
    with _StyleContext(style_dict):
        tmp = df[[date_col, y_col]].dropna(subset=[date_col]).copy()
        tmp["_year"] = tmp[date_col].dt.year
        if period.lower().startswith("q"):
            tmp["_period_pos"] = tmp[date_col].dt.quarter
            x_label = "Quarter"
        else:
            tmp["_period_pos"] = tmp[date_col].dt.month
            x_label = "Month"

        years = sorted(tmp["_year"].unique())
        n_years = len(years)
        palette = list(palette_name_colors or []) or [MIAMI_RED]

        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        for i, year in enumerate(years):
            sub = tmp[tmp["_year"] == year].sort_values("_period_pos")
            ax.plot(
                sub["_period_pos"], sub[y_col],
                marker="o", markersize=4, linewidth=1.4,
                color=palette[i % len(palette)], label=str(year),
            )

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        # No years (empty frame / all dates missing) means no legend handles.
        if n_years:
            ax.legend(title="Year", loc="best", fontsize=8, ncol=max(1, n_years // 6))
        return _finish_figure(fig)
# ===================================================================
# 4. Seasonal sub-series
# ===================================================================
def plot_seasonal_subseries(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    period: str,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Subseries plot: one panel per season, each with a dashed mean line.

    Inside a panel, observations for that season are drawn in date order at
    evenly spaced x positions (the x ticks are hidden).  All panels share
    the y axis.

    Parameters
    ----------
    df:
        Source data; *date_col* must be datetime-typed (``.dt`` is used).
    date_col, y_col:
        Date column and value column.
    period:
        ``"month"`` or ``"quarter"``; any string starting with ``"q"``
        (case-insensitive) selects quarters.
    title:
        Optional figure-level title.
    style_dict:
        Optional rcParams overrides applied while the figure is built.
    palette_colors:
        The first colour is used for the series lines (``MIAMI_RED`` if
        none); the mean lines are always ``MIAMI_RED``.
    """
    with _StyleContext(style_dict):
        # Rows without a date cannot be assigned a season; keeping them
        # would create a bogus "nan" panel.
        tmp = df[[date_col, y_col]].dropna(subset=[date_col]).copy()
        if period.lower().startswith("q"):
            tmp["_season"] = tmp[date_col].dt.quarter
            labels = {1: "Q1", 2: "Q2", 3: "Q3", 4: "Q4"}
        else:
            tmp["_season"] = tmp[date_col].dt.month
            labels = {
                1: "Jan", 2: "Feb", 3: "Mar", 4: "Apr",
                5: "May", 6: "Jun", 7: "Jul", 8: "Aug",
                9: "Sep", 10: "Oct", 11: "Nov", 12: "Dec",
            }

        seasons = sorted(tmp["_season"].unique())
        n = len(seasons)
        fig_w = max(10, n * 1.3)
        # plt.subplots(1, 0) raises, so always create at least one panel;
        # squeeze=False gives a 2-D array even for a single panel.
        fig, axes = plt.subplots(
            1, max(n, 1), figsize=(fig_w, 5), sharey=True, squeeze=False,
        )
        panels = axes[0]

        color = _default_color(palette_colors, 0)
        for ax, season in zip(panels, seasons):
            sub = tmp[tmp["_season"] == season].sort_values(date_col)
            ax.plot(range(len(sub)), sub[y_col].values, marker="o", markersize=3,
                    linewidth=1.2, color=color)
            ax.axhline(sub[y_col].mean(), color=MIAMI_RED, linewidth=1.8,
                       linestyle="--", alpha=0.8)
            ax.set_title(labels.get(season, str(season)), fontsize=10)
            ax.set_xticks([])
            ax.tick_params(axis="y", labelsize=8)
        panels[0].set_ylabel(y_col)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
# ===================================================================
# 5. ACF / PACF
# ===================================================================
def plot_acf_pacf(
    acf_vals: np.ndarray,
    acf_ci: np.ndarray,
    pacf_vals: np.ndarray,
    pacf_ci: np.ndarray,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """ACF (left) and PACF (right) bar charts with a shaded confidence band.

    Parameters
    ----------
    acf_vals, pacf_vals:
        1-D arrays of (partial) autocorrelations, one per lag starting at 0;
        bar ``i`` is plotted at x = ``i``.
    acf_ci, pacf_ci:
        2-D arrays whose column 0 is the lower and column 1 the upper CI
        bound for each lag.  NOTE(review): statsmodels' ``acf``/``pacf``
        return intervals centred on the estimates, not on zero; confirm
        what the caller passes.
    title:
        Optional figure-level title.
    style_dict:
        Optional rcParams overrides applied while the figure is built.
    """
    with _StyleContext(style_dict):
        fig, (acf_ax, pacf_ax) = plt.subplots(1, 2, figsize=(12, 5))
        panels = (
            ("ACF", acf_ax, acf_vals, acf_ci),
            ("PACF", pacf_ax, pacf_vals, pacf_ci),
        )
        for panel_name, ax, coeffs, bounds in panels:
            positions = np.arange(len(coeffs))
            ax.bar(positions, coeffs, width=0.3, color=MIAMI_RED, alpha=0.85, zorder=2)
            # Confidence band drawn behind the bars (lower zorder).
            ax.fill_between(
                positions, bounds[:, 0], bounds[:, 1],
                color="#C41230", alpha=0.12, zorder=1,
            )
            ax.axhline(0, color="black", linewidth=0.8)
            ax.set_xlabel("Lag")
            ax.set_ylabel("Correlation")
            ax.set_title(panel_name)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
# ===================================================================
# 6. Decomposition
# ===================================================================
def plot_decomposition(
    decomposition_result,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """4-panel plot: observed, trend, seasonal, residual (shared x axis).

    Parameters
    ----------
    decomposition_result:
        An object with ``.observed``, ``.trend``, ``.seasonal`` and ``.resid``
        attributes (e.g. from ``statsmodels.tsa.seasonal_decompose``).  The
        components may be pandas objects, in which case their index is used
        for x, or plain arrays, which are plotted against 0..n-1.
    title:
        Optional figure-level title.
    style_dict:
        Optional rcParams overrides applied while the figure is built.

    Notes
    -----
    Date tick formatting is applied only when the observed component has a
    datetime index; integer/positional x axes keep normal numeric ticks.
    """
    with _StyleContext(style_dict):
        components = [
            ("Observed", decomposition_result.observed),
            ("Trend", decomposition_result.trend),
            ("Seasonal", decomposition_result.seasonal),
            ("Residual", decomposition_result.resid),
        ]
        fig, axes = plt.subplots(4, 1, figsize=(10, 10), sharex=True)
        for ax, (label, component) in zip(axes, components):
            values = np.asarray(component)
            # Bare ndarray components (non-pandas input) carry no index.
            x = getattr(component, "index", None)
            if x is None:
                x = np.arange(len(values))
            ax.plot(x, values, linewidth=1.2, color=MIAMI_RED)
            ax.set_ylabel(label, fontsize=10)
            ax.tick_params(axis="both", labelsize=9)

        # Date formatting on the shared x-axis (bottom panel), only for dates.
        observed_index = getattr(decomposition_result.observed, "index", None)
        if observed_index is not None and pd.api.types.is_datetime64_any_dtype(observed_index):
            _auto_date_axis(axes[-1])

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.01)
        return _finish_figure(fig)
# ===================================================================
# 7. Rolling overlay
# ===================================================================
def plot_rolling_overlay(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    window: int,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Faint original series overlaid with a bold centred rolling mean and +/-1 std band.

    The rolling statistics use ``window`` points and ``center=True``.  The
    raw series takes the first palette colour; the mean line and band take
    the second palette colour, or dark grey ``#333333`` if the palette has
    fewer than two entries.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)

        raw_color = _default_color(palette_colors, 0)
        mean_color = "#333333"
        if palette_colors and len(palette_colors) > 1:
            mean_color = _default_color(palette_colors, 1)

        x = df[date_col]
        series = df[y_col]
        roller = series.rolling(window=window, center=True)
        smooth = roller.mean()
        spread = roller.std()

        ax.plot(x, series, linewidth=0.8, alpha=0.4, color=raw_color, label="Original")
        ax.plot(
            x, smooth, linewidth=2.2, color=mean_color,
            label=f"{window}-pt Rolling Mean",
        )
        ax.fill_between(
            x, smooth - spread, smooth + spread,
            alpha=0.15, color=mean_color, label="\u00b11 Std Dev",
        )

        ax.set_xlabel(date_col)
        ax.set_ylabel(y_col)
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)
        ax.legend(loc="best")
        return _finish_figure(fig)
# ===================================================================
# 8. Year-over-Year change
# ===================================================================
def plot_yoy_change(
    df: pd.DataFrame,
    date_col: str,
    y_col: str,
    yoy_df: pd.DataFrame,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Two stacked bar charts: absolute YoY change (top), % YoY change (bottom).

    Bars are green for non-negative changes and red for negative ones.

    Parameters
    ----------
    df, date_col, y_col:
        Not used by this function; kept so existing call sites keep working.
    yoy_df:
        DataFrame with columns ``"date"``, ``"abs_change"``, ``"pct_change"``.
    title:
        Optional figure-level title.
    style_dict:
        Optional rcParams overrides applied while the figure is built.

    Notes
    -----
    When ``"date"`` is datetime-typed and the median spacing between sorted
    dates is positive, each bar is 90% of that spacing wide.  This keeps
    bars from overlapping for daily/weekly data and from being hair-thin
    for annual data.  Otherwise the fixed width of 20 (days, on a date axis)
    is used.
    """
    with _StyleContext(style_dict):
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
        dates = yoy_df["date"]
        abs_change = yoy_df["abs_change"]
        pct_change = yoy_df["pct_change"]

        bar_width = 20
        if pd.api.types.is_datetime64_any_dtype(dates):
            gaps = dates.dropna().sort_values().diff().dropna()
            if len(gaps):
                median_gap = gaps.median()
                if median_gap > pd.Timedelta(0):
                    bar_width = median_gap * 0.9

        # Colours: green for positive, red for negative
        abs_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in abs_change]
        pct_colors = ["#2ca02c" if v >= 0 else "#d62728" for v in pct_change]

        ax1.bar(dates, abs_change, color=abs_colors, width=bar_width,
                edgecolor="white", linewidth=0.3)
        ax1.axhline(0, color="black", linewidth=0.6)
        ax1.set_ylabel("Absolute Change")
        ax1.set_title("Year-over-Year Absolute Change")

        ax2.bar(dates, pct_change, color=pct_colors, width=bar_width,
                edgecolor="white", linewidth=0.3)
        ax2.axhline(0, color="black", linewidth=0.6)
        ax2.set_ylabel("% Change")
        ax2.set_title("Year-over-Year Percentage Change")

        _auto_date_axis(ax2)
        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
# ===================================================================
# 9. Lag plot
# ===================================================================
def plot_lag(
    series: pd.Series,
    lag: int = 1,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
) -> matplotlib.figure.Figure:
    """Scatter y(t) against y(t-lag) and annotate the Pearson correlation.

    Pairs are formed by position in *series* first; only then are pairs
    containing a missing value discarded.  Each plotted point is therefore
    exactly *lag* observations apart.  Dropping NaNs before lagging would
    instead pair values across gaps.

    Parameters
    ----------
    series:
        Numeric series in time order.
    lag:
        Positive number of positions separating the paired observations.
    title:
        Axes title; defaults to ``"Lag-<lag> Plot"``.
    style_dict:
        Optional rcParams overrides applied while the figure is built.

    Raises
    ------
    ValueError
        If *lag* is less than 1.

    Notes
    -----
    The annotation reads ``r = nan`` when fewer than two complete pairs
    exist or either side is constant.  In those cases ``np.corrcoef`` would
    only emit a warning and return NaN anyway.
    """
    if lag < 1:
        raise ValueError(f"lag must be a positive integer, got {lag!r}")
    with _StyleContext(style_dict):
        values = series.to_numpy(dtype=float, na_value=np.nan)
        y_t = values[lag:]
        y_lag = values[:-lag]
        complete = ~(np.isnan(y_t) | np.isnan(y_lag))
        y_t, y_lag = y_t[complete], y_lag[complete]

        if y_t.size >= 2 and np.ptp(y_t) > 0 and np.ptp(y_lag) > 0:
            corr = float(np.corrcoef(y_t, y_lag)[0, 1])
        else:
            corr = float("nan")

        fig, ax = plt.subplots(figsize=(7, 7))
        ax.scatter(y_lag, y_t, alpha=0.5, s=20, color=MIAMI_RED,
                   edgecolors="white", linewidths=0.3)
        # Annotation
        ax.annotate(
            f"r = {corr:.3f}",
            xy=(0.05, 0.95), xycoords="axes fraction",
            fontsize=12, fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="#CCCCCC", alpha=0.9),
            verticalalignment="top",
        )
        ax.set_xlabel(f"y(t\u2212{lag})")
        ax.set_ylabel("y(t)")
        if title:
            ax.set_title(title)
        else:
            ax.set_title(f"Lag-{lag} Plot")
        return _finish_figure(fig)
# ===================================================================
# 10. Panel (small multiples)
# ===================================================================
def plot_panel(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    chart_type: str = "line",
    shared_y: bool = True,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Small multiples: one subplot per entry of *y_cols*, laid out by ``_grid_dims``.

    Parameters
    ----------
    chart_type:
        ``"bar"`` draws bars; any other value draws lines.
    shared_y:
        If ``True`` all panels share the same y-axis limits.

    In bar mode the bar width is 90% of the median gap between the sorted,
    non-missing dates (or 1 when fewer than two dates exist).  Grid cells
    beyond ``len(y_cols)`` are hidden.
    """
    with _StyleContext(style_dict):
        n_series = len(y_cols)
        nrows, ncols = _grid_dims(n_series)
        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=(max(8, ncols * 4.5), max(4, nrows * 3.5)),
            sharey=shared_y,
            squeeze=False,
        )
        panel_axes = axes.flatten()
        draw_bars = chart_type == "bar"

        bar_width = None
        if draw_bars:
            ordered = df[date_col].dropna().sort_values()
            bar_width = ordered.diff().dropna().median() * 0.9 if len(ordered) >= 2 else 1

        for i, (ax, col) in enumerate(zip(panel_axes, y_cols)):
            shade = _default_color(palette_colors, i)
            if draw_bars:
                ax.bar(
                    df[date_col], df[col], width=bar_width, color=shade,
                    alpha=0.85, edgecolor="white", linewidth=0.3,
                )
            else:
                ax.plot(df[date_col], df[col], linewidth=1.3, color=shade)
            ax.set_title(col, fontsize=10)
            _auto_date_axis(ax)

        # Blank out grid cells left over when n_series does not fill the grid.
        for spare in panel_axes[n_series:]:
            spare.set_visible(False)

        if title:
            fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        return _finish_figure(fig)
# ===================================================================
# 11. Spaghetti plot
# ===================================================================
def plot_spaghetti(
    df: pd.DataFrame,
    date_col: str,
    y_cols: List[str],
    alpha: float = 0.15,
    highlight_col: Optional[str] = None,
    top_n: Optional[int] = None,
    show_median_band: bool = False,
    title: Optional[str] = None,
    style_dict: Optional[Dict[str, object]] = None,
    palette_colors: Optional[List[str]] = None,
) -> matplotlib.figure.Figure:
    """Every series on one axes at low opacity, with optional emphasis.

    Parameters
    ----------
    alpha:
        Opacity for the background (non-highlighted) lines.
    highlight_col:
        Column drawn thicker and nearly opaque; ignored unless it is one of
        *y_cols*.
    top_n:
        If truthy, also highlight the *top_n* series with the largest
        maximum value.
    show_median_band:
        If ``True``, overlay the row-wise median line and shade the IQR
        (25th-75th percentile) across *y_cols*.

    Only highlighted series, the median and the IQR carry labels.  A legend
    is added only if at least one labelled artist exists.
    """
    with _StyleContext(style_dict):
        fig, ax = plt.subplots(figsize=_DEFAULT_FIG_SIZE)
        x = df[date_col]

        emphasised: set[str] = set()
        if highlight_col and highlight_col in y_cols:
            emphasised.add(highlight_col)
        if top_n:
            peak = {col: df[col].max() for col in y_cols}
            ranked = sorted(peak, key=peak.get, reverse=True)  # type: ignore[arg-type]
            emphasised.update(ranked[:top_n])

        for i, col in enumerate(y_cols):
            if col in emphasised:
                line_kw = {"linewidth": 2.0, "alpha": 0.9, "label": col, "zorder": 3}
            else:
                line_kw = {"linewidth": 0.8, "alpha": alpha, "zorder": 1}
            ax.plot(x, df[col], color=_default_color(palette_colors, i), **line_kw)

        if show_median_band:
            block = df[y_cols]
            lower_q = block.quantile(0.25, axis=1)
            upper_q = block.quantile(0.75, axis=1)
            ax.plot(x, block.median(axis=1), linewidth=2.2, color="#333333",
                    label="Median", zorder=4)
            ax.fill_between(x, lower_q, upper_q, alpha=0.2, color="#333333",
                            label="IQR", zorder=2)

        ax.set_xlabel(date_col)
        ax.set_ylabel("Value")
        if title:
            ax.set_title(title)
        _auto_date_axis(ax)

        _, legend_labels = ax.get_legend_handles_labels()
        if legend_labels:
            ax.legend(loc="best", fontsize=8)
        return _finish_figure(fig)