diff --git a/_quarto.yml b/_quarto.yml index 63015e8e..ec490fa3 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -116,6 +116,8 @@ website: file: docs/reference/model_selection.utils_metrics.qmd - section: "Plots" contents: + - text: "diagnostics" + file: docs/reference/plots.diagnostics.qmd - text: "outlier_plots" file: docs/reference/plots.outlier_plots.qmd - text: "plotter" @@ -215,8 +217,9 @@ quartodoc: sections: - title: "Plots" - desc: "Outlier, prediction, and time-series visualization tools." + desc: "Outlier, prediction, time-series, and operational diagnostic visualization tools." contents: + - plots.diagnostics - plots.outlier_plots - plots.plotter - plots.time_series_visualization diff --git a/src/spotforecast2/plots/__init__.py b/src/spotforecast2/plots/__init__.py index ba0923de..7f14046a 100644 --- a/src/spotforecast2/plots/__init__.py +++ b/src/spotforecast2/plots/__init__.py @@ -1,6 +1,12 @@ # SPDX-FileCopyrightText: 2026 bartzbeielstein # SPDX-License-Identifier: AGPL-3.0-or-later +from .diagnostics import ( + plot_acf_with_lags, + plot_feature_importance_by_family, + plot_forecast_vs_reference, + plot_shap_summary, +) from .distribution import plot_distribution, plot_distributions from .outlier_plots import ( visualize_outliers_hist, @@ -13,9 +19,13 @@ ) __all__ = [ + "plot_acf_with_lags", "plot_distribution", "plot_distributions", + "plot_feature_importance_by_family", + "plot_forecast_vs_reference", "plot_periodogram", + "plot_shap_summary", "visualize_outliers_hist", "visualize_outliers_plotly_scatter", "visualize_ts_plotly", diff --git a/src/spotforecast2/plots/diagnostics.py b/src/spotforecast2/plots/diagnostics.py new file mode 100644 index 00000000..f5d68b34 --- /dev/null +++ b/src/spotforecast2/plots/diagnostics.py @@ -0,0 +1,375 @@ +# SPDX-FileCopyrightText: 2026 bartzbeielstein +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Operational diagnostic plots for energy-load forecasting pipelines. 
def feature_family(name: str) -> str:
    """Classify a feature name into the family used for colour grouping.

    The name is lower-cased and tested against an ordered table of substring
    rules; the first rule that matches wins. Precedence is therefore
    holiday, polynomial, weather_window, cyclical/RBF, followed by a prefix
    check for lags and finally the catch-all family.

    The matching is intentionally coarse: it is plain substring containment,
    so any name that merely contains `"holiday"`, `"brueckentag"`, `"poly"`,
    `"window"`, `"sin"`, `"cos"` or `"rbf"` lands in that family (for
    example `"is_business_day"` contains `"sin"`). Only the lag rule is
    anchored to the start of the name, so `"flag_col"` is not a lag. The
    result is used for colour grouping in
    `plot_feature_importance_by_family`.

    Args:
        name: Feature name, e.g. an entry of `LGBMRegressor.feature_name_`.

    Returns:
        One of: `"holiday"`, `"polynomial"`, `"weather_window"`,
        `"cyclical/RBF"`, `"lag"`, `"weather/other"`.

    Examples:
        ```{python}
        from spotforecast2.plots.diagnostics import feature_family

        print(feature_family("holiday_DE"))
        print(feature_family("brueckentag_NW"))
        print(feature_family("poly_hour_2"))
        print(feature_family("window_mean_72"))
        print(feature_family("sin_hour"))
        print(feature_family("lag_1"))
        print(feature_family("wind_speed_10m"))
        ```
    """
    lowered = name.lower()

    # Order matters: earlier rows take precedence over later ones.
    substring_rules = (
        ("holiday", ("holiday", "brueckentag")),
        ("polynomial", ("poly",)),
        ("weather_window", ("window",)),
        ("cyclical/RBF", ("sin", "cos", "rbf")),
    )
    for family, needles in substring_rules:
        for needle in needles:
            if needle in lowered:
                return family

    # Lags are matched by prefix only, never by containment.
    return "lag" if lowered.startswith("lag") else "weather/other"
def plot_acf_with_lags(
    acf: pd.DataFrame,
    key_lags: Sequence[int],
    conf: float,
) -> Figure:
    """Draw the autocorrelation function as bars and label selected lags.

    Ports `_plot_acf` from the chapter-14 team4 script. One bar is drawn per
    row of `acf`; two dashed grey horizontal lines mark `+conf` and `-conf`.
    For every entry of `key_lags` that occurs in `acf["lag"]`, an orange
    `"lag <n>"` label with an arrow is placed 0.08 above the bar tip. If a
    lag value occurs more than once in the frame, the first matching row is
    used. Key lags that do not occur in the frame are skipped silently.

    Args:
        acf: DataFrame with at least the columns `"lag"` and
            `"autocorrelation"`, e.g. the output of
            `spotforecast2.stats.autocorrelation.calculate_lag_autocorrelation`.
        key_lags: Lag values to annotate (e.g. the PACF-selected lags from
            the pipeline).
        conf: Half-width of the confidence band, typically
            `1.96 / sqrt(n_obs)`.

    Returns:
        A `matplotlib.figure.Figure`. The caller is responsible for saving
        and closing it.

    Examples:
        ```{python}
        import numpy as np
        import pandas as pd
        from spotforecast2.plots.diagnostics import plot_acf_with_lags

        rng = np.random.default_rng(0)
        n = 100
        acf = pd.DataFrame({"lag": range(n), "autocorrelation": rng.uniform(-0.3, 0.3, n)})
        fig = plot_acf_with_lags(acf, key_lags=[1, 24, 48], conf=0.1)
        print(type(fig).__name__)
        ```
    """
    band_style = {"color": "#999999", "lw": 0.8, "ls": "--"}
    highlight = "#E07B00"

    fig, ax = plt.subplots(figsize=(8.5, 3.0))
    ax.bar(acf["lag"], acf["autocorrelation"], width=0.9, color="#1F4E79")
    for level in (conf, -conf):
        ax.axhline(level, **band_style)

    for lag in key_lags:
        matches = acf.loc[acf["lag"] == lag, "autocorrelation"]
        if matches.empty:
            # Requested lag is not part of the ACF frame: nothing to label.
            continue
        height = float(matches.iloc[0])
        ax.annotate(
            f"lag {lag}",
            xy=(lag, height),
            xytext=(lag, height + 0.08),
            ha="center",
            fontsize=8,
            color=highlight,
            arrowprops={"arrowstyle": "->", "color": highlight, "lw": 0.8},
        )

    ax.set(
        xlabel="lag (hours)",
        ylabel="autocorrelation",
        title="ACF of Actual Load (annotated: data-selected key_lags)",
    )
    ax.grid(True, color="#E5E5E5", linewidth=0.5)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    fig.tight_layout()
    return fig
def plot_feature_importance_by_family(
    names: Sequence[str],
    importances: Sequence[float],
    *,
    top_n: int = 20,
) -> Figure:
    """Horizontal bar chart of the top-N feature importances, coloured by family.

    Ports `_plot_importance` from the chapter-14 team4 script. Features are
    ranked by descending importance (ties keep their input order), the first
    `top_n` are kept, and each bar is coloured by the family that
    `feature_family` assigns to its name. The legend always lists every
    family of the palette, whether or not it occurs among the plotted bars.

    The title reports the number of bars actually drawn, so a model with
    fewer than `top_n` features is not mislabelled as "Top-20".

    Args:
        names: Feature names (e.g. `fc.estimator.feature_name_`).
        importances: Importance scores aligned with `names` (e.g.
            `fc.estimator.feature_importances_`).
        top_n: Maximum number of features to display. Defaults to 20.

    Returns:
        A `matplotlib.figure.Figure`. The caller is responsible for saving
        and closing it.

    Raises:
        ValueError: If `names` and `importances` differ in length, or if
            `top_n` is negative.

    Examples:
        ```{python}
        from spotforecast2.plots.diagnostics import plot_feature_importance_by_family

        names = ["lag_1", "lag_24", "holiday_DE", "wind_speed_10m",
                 "poly_hour_2", "window_mean_72", "sin_hour"]
        scores = [100, 80, 60, 55, 40, 35, 20]
        fig = plot_feature_importance_by_family(names, scores, top_n=5)
        print(type(fig).__name__)
        ```
    """
    # zip() would silently drop the unmatched tail and hide a caller bug.
    if len(names) != len(importances):
        raise ValueError(
            "names and importances must have the same length, got "
            f"{len(names)} and {len(importances)}"
        )
    # A negative slice bound would drop the *least* important features
    # instead of limiting the chart, which is never what the caller means.
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    ranking = sorted(
        zip(names, importances), key=lambda kv: kv[1], reverse=True
    )[:top_n]
    # barh places the first entry at the bottom; reverse so the most
    # important feature ends up at the top of the chart.
    labels = [n for n, _ in reversed(ranking)]
    values = [v for _, v in reversed(ranking)]
    colors = [_FAMILY_COLOR.get(feature_family(n), "#888888") for n in labels]

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    ax.barh(labels, values, color=colors)
    ax.set_xlabel("split count (feature importance)")
    ax.set_title(
        f"Top-{len(ranking)} feature importances "
        "(coloured by family; lags in red)"
    )
    ax.grid(True, axis="x", color="#E5E5E5", linewidth=0.5)
    handles = [
        plt.Rectangle((0, 0), 1, 1, color=c) for c in _FAMILY_COLOR.values()
    ]
    ax.legend(handles, list(_FAMILY_COLOR), fontsize=7, loc="lower right")
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    fig.tight_layout()
    return fig
def plot_shap_summary(
    estimator: object,
    X: pd.DataFrame,
    *,
    max_samples: int = 2000,
) -> Figure:
    """SHAP bar-summary plot for a tree-based estimator.

    Ports `_plot_shap` from the chapter-14 team4 script. `X` is subsampled
    with a uniform stride of `len(X) // max_samples` (at least 1) before the
    SHAP values are computed, so the call stays fast for large training
    matrices. The row count is approximate: any `X` shorter than
    `2 * max_samples` is passed in full, and the number of rows used always
    stays below `2 * max_samples`.

    The plot is produced by `shap.TreeExplainer` and
    `shap.summary_plot(plot_type="bar", show=False)`. A new figure is made
    current immediately before SHAP draws, so a figure the caller still has
    open (for example one returned by another helper of this module) is not
    drawn over. The result is then collected with `plt.gcf()`; because this
    relies on the global pyplot state the function is **not thread-safe**.
    Callers should close the returned figure (e.g. `plt.close(fig)`) when
    done.

    Args:
        estimator: A fitted tree-based estimator supported by
            `shap.TreeExplainer` (e.g. `LGBMRegressor`).
        X: Feature matrix; typically the design matrix returned by
            `ForecasterRecursive.create_train_X_y`.
        max_samples: Approximate row budget for the SHAP explainer; must be
            at least 1. Defaults to 2000.

    Returns:
        A `matplotlib.figure.Figure`.

    Raises:
        ValueError: If `max_samples` is smaller than 1.

    Examples:
        ```{python}
        #| eval: false
        import numpy as np
        import pandas as pd
        from lightgbm import LGBMRegressor
        from spotforecast2.plots.diagnostics import plot_shap_summary

        rng = np.random.default_rng(0)
        X = pd.DataFrame(rng.standard_normal((200, 5)),
                         columns=[f"f{i}" for i in range(5)])
        y = X["f0"] + rng.standard_normal(200) * 0.1
        est = LGBMRegressor(n_estimators=20, verbose=-1)
        est.fit(X, y)
        fig = plot_shap_summary(est, X, max_samples=100)
        print(type(fig).__name__)
        ```
    """
    # Reject a non-positive budget up front instead of failing with a bare
    # ZeroDivisionError in the stride computation below.
    if max_samples < 1:
        raise ValueError(
            f"max_samples must be a positive integer, got {max_samples}"
        )

    # Imported lazily so the module stays importable without shap installed.
    import shap

    step = max(1, len(X) // max_samples)
    X_sample = X.iloc[::step]
    explainer = shap.TreeExplainer(estimator)
    sv = explainer.shap_values(X_sample)

    # SHAP draws on the current pyplot figure. Give it a fresh one so an
    # already-open figure of the caller is never overwritten.
    canvas = plt.figure()
    try:
        shap.summary_plot(sv, X_sample, plot_type="bar", show=False)
    except Exception:
        # Do not leak the figure we created if SHAP fails part-way.
        plt.close(canvas)
        raise

    fig = plt.gcf()
    if fig is not canvas:
        # SHAP opened its own figure; drop the unused placeholder.
        plt.close(canvas)
    fig.tight_layout()
    return fig
`plt.close(fig)`) before performing + other pyplot work. + + Args: + estimator: A fitted tree-based estimator supported by + `shap.TreeExplainer` (e.g. `LGBMRegressor`). + X: Feature matrix; typically the design matrix returned by + `ForecasterRecursive.create_train_X_y`. + max_samples: Row budget passed to the SHAP explainer. Defaults to 2000. + + Returns: + A `matplotlib.figure.Figure`. + + Examples: + ```{python} + #| eval: false + import numpy as np + import pandas as pd + from lightgbm import LGBMRegressor + from spotforecast2.plots.diagnostics import plot_shap_summary + + rng = np.random.default_rng(0) + X = pd.DataFrame(rng.standard_normal((200, 5)), + columns=[f"f{i}" for i in range(5)]) + y = X["f0"] + rng.standard_normal(200) * 0.1 + est = LGBMRegressor(n_estimators=20, verbose=-1) + est.fit(X, y) + fig = plot_shap_summary(est, X, max_samples=100) + print(type(fig).__name__) + ``` + """ + import shap + + step = max(1, len(X) // max_samples) + X_sample = X.iloc[::step] + explainer = shap.TreeExplainer(estimator) + sv = explainer.shap_values(X_sample) + shap.summary_plot(sv, X_sample, plot_type="bar", show=False) + fig = plt.gcf() + fig.tight_layout() + return fig + + +def plot_forecast_vs_reference( + forecast: pd.Series, + reference: pd.Series, + *, + forecast_label: str = "forecast", + reference_label: str = "reference", + unit_scale: float = 1e-3, + unit: str = "GW", +) -> Figure: + """Line plot comparing a forecast against an optional reference series. + + Ports `_plot_vs_entsoe` from the chapter-14 team4 script into a general, + label-parametrised form. The reference is reindexed to `forecast.index`; + only the overlapping (non-NaN) timestamps are plotted. If there is no + overlap the reference line is omitted and an INFO message is logged — the + function still returns a valid single-line figure rather than raising. + + Both series are scaled by `unit_scale` before plotting (default `1e-3` + converts MW to GW). 
+ + The overlap MAD (mean absolute deviation between `forecast` and + `reference` over shared timestamps) is logged at INFO level when an + overlap exists. This mirrors the original script's behaviour and is + useful for post-hoc sanity checks in operator logs. + + Args: + forecast: Point forecast series with a `DatetimeIndex`. + reference: Reference series (e.g. ENTSO-E day-ahead forecast). + Will be reindexed to `forecast.index`; NaN rows after reindexing + are treated as "no overlap" for that timestamp. + forecast_label: Legend label for the forecast line. + reference_label: Legend label for the reference line. + unit_scale: Multiplicative scale applied to both series before + plotting. Defaults to `1e-3` (MW → GW). + unit: Unit string used in the y-axis label. + + Returns: + A `matplotlib.figure.Figure`. + + Examples: + ```{python} + import numpy as np + import pandas as pd + from spotforecast2.plots.diagnostics import plot_forecast_vs_reference + + idx = pd.date_range("2024-01-15", periods=24, freq="h", tz="UTC") + rng = np.random.default_rng(42) + forecast = pd.Series(40_000 + rng.standard_normal(24) * 1000, index=idx) + reference = pd.Series(41_000 + rng.standard_normal(24) * 800, index=idx) + + fig = plot_forecast_vs_reference( + forecast, reference, + forecast_label="team_4 forecast", + reference_label="ENTSO-E day-ahead", + ) + print(type(fig).__name__) + ``` + """ + ref_aligned = reference.reindex(forecast.index) + overlap = ref_aligned.dropna().index + + fig, ax = plt.subplots(figsize=(8.5, 3.5)) + ax.plot( + forecast.index, + forecast.values * unit_scale, + marker="o", + ms=3, + color="#1F4E79", + label=forecast_label, + ) + if len(overlap) > 0: + ax.plot( + ref_aligned.index, + ref_aligned.values * unit_scale, + marker="x", + ms=4, + color="#E07B00", + label=reference_label, + ) + mad = float((forecast.loc[overlap] - ref_aligned.loc[overlap]).abs().mean()) + logger.info( + "mean |%s - %s| over %d overlap hours: %.1f %s", + forecast_label, + 
reference_label, + len(overlap), + mad * unit_scale, + unit, + ) + else: + logger.info( + "%s not available for forecast period; plotting %s only.", + reference_label, + forecast_label, + ) + ax.set_xlabel("Time (UTC)") + ax.set_ylabel(f"Load [{unit}]") + ax.set_title(f"{forecast_label} vs. {reference_label} — target day") + ax.grid(True, color="#E5E5E5", linewidth=0.5) + ax.legend(fontsize=8) + for spine in ("top", "right"): + ax.spines[spine].set_visible(False) + fig.autofmt_xdate() + fig.tight_layout() + return fig diff --git a/tests/test_plots_diagnostics.py b/tests/test_plots_diagnostics.py new file mode 100644 index 00000000..d8b4fdb5 --- /dev/null +++ b/tests/test_plots_diagnostics.py @@ -0,0 +1,344 @@ +# SPDX-FileCopyrightText: 2026 bartzbeielstein +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Tests for spotforecast2.plots.diagnostics. + +Headless matplotlib backend must be set BEFORE pyplot is imported. The +`matplotlib.use()` call at module level (before any pyplot import) satisfies +that requirement. 
@pytest.fixture(autouse=True)
def close_all_figures():
    """Release every pyplot figure once each test body has finished."""
    yield
    plt.close("all")


# Probe names grouped by the family `feature_family` must return for them.
# Two entries pin the rule precedence: "poly_window_24" (poly beats window)
# and "sin_window" (window beats cyclical).
_NAMES_BY_FAMILY = {
    "holiday": (
        "holiday_DE",
        "is_holiday",
        "HOLIDAY_NW",
        "brueckentag_NW",
        "BRUECKENTAG_DE",
        "day_after_holiday",
        "day_before_holiday",
    ),
    "polynomial": (
        "poly_hour_2",
        "poly_dayofweek_3",
        "POLY_month",
        "poly_window_24",
    ),
    "weather_window": (
        "window_mean_72",
        "roll_window_168",
        "WINDOW_std_24",
        "sin_window",
    ),
    "cyclical/RBF": (
        "sin_hour",
        "cos_hour",
        "rbf_hour_0",
        "SIN_dayofyear",
        "COS_month",
    ),
    "lag": (
        "lag_1",
        "lag_24",
        "lag_168",
    ),
    # catch-all
    "weather/other": (
        "wind_speed_10m",
        "temperature_2m",
        "entsoe_forecasted_load",
        "is_workday",
        "day_type",
        "solar_zenith",
    ),
}

# Flattened to (name, expected_family) pairs for pytest parametrisation.
_FAMILY_CASES = [
    (name, family)
    for family, names in _NAMES_BY_FAMILY.items()
    for name in names
]
@pytest.mark.parametrize("name,expected", _FAMILY_CASES)
def test_feature_family_mapping(name, expected):
    """Each (name, family) pair of the lookup table resolves as listed."""
    resolved = feature_family(name)
    assert resolved == expected


def test_feature_family_case_insensitive_for_holiday():
    """A mixed-case holiday name is still recognised as holiday."""
    assert feature_family("Holiday_NW") == "holiday"


def test_feature_family_lag_prefix_only():
    """The lag family requires the name to START with 'lag'."""
    assert feature_family("lag_1") == "lag"
    # 'flag' contains 'lag' but does not begin with it.
    assert feature_family("flag_col") == "weather/other"


# ---------------------------------------------------------------------------
# plot_acf_with_lags
# ---------------------------------------------------------------------------


def _make_acf(n: int = 60, seed: int = 0) -> pd.DataFrame:
    """Build a synthetic ACF frame with lags 0..n-1 and seeded random values."""
    generator = np.random.default_rng(seed)
    columns = {
        "lag": np.arange(n),
        "autocorrelation": generator.uniform(-0.4, 0.4, n),
    }
    return pd.DataFrame(columns)


def test_plot_acf_returns_figure():
    """The helper hands back a matplotlib Figure."""
    fig = plot_acf_with_lags(_make_acf(60), key_lags=[1, 24, 48], conf=0.1)
    assert isinstance(fig, Figure)


def test_plot_acf_annotation_count_equals_present_key_lags():
    """Only key_lags that occur in the ACF frame receive an annotation."""
    acf = _make_acf(50)  # lags 0..49
    key_lags = [1, 24, 48, 100]  # 100 lies outside the frame
    available = acf["lag"].values
    expected_annotations = len([lag for lag in key_lags if lag in available])

    fig = plot_acf_with_lags(acf, key_lags=key_lags, conf=0.1)

    axis = fig.axes[0]
    assert len(axis.texts) == expected_annotations
def test_plot_acf_empty_key_lags_no_annotations():
    """An empty key_lags sequence yields a plot without annotations."""
    acf = _make_acf(30)
    fig = plot_acf_with_lags(acf, key_lags=[], conf=0.1)
    ax = fig.axes[0]
    assert len(ax.texts) == 0


def test_plot_acf_all_key_lags_missing_no_annotations():
    """key_lags entirely outside the acf range produce zero annotations."""
    acf = _make_acf(10)  # lags 0..9
    fig = plot_acf_with_lags(acf, key_lags=[50, 100, 168], conf=0.1)
    ax = fig.axes[0]
    assert len(ax.texts) == 0


# ---------------------------------------------------------------------------
# plot_feature_importance_by_family
# ---------------------------------------------------------------------------


def test_importance_plot_returns_figure():
    """The helper hands back a matplotlib Figure."""
    names = ["lag_1", "lag_24", "holiday_DE", "wind_speed"]
    scores = [100, 80, 60, 40]
    fig = plot_feature_importance_by_family(names, scores)
    assert isinstance(fig, Figure)


def test_importance_plot_top_n_respected():
    """With more features than top_n, exactly top_n bars are drawn."""
    rng = np.random.default_rng(1)
    n_features = 30
    names = [f"feature_{i}" for i in range(n_features)]
    scores = rng.integers(1, 100, n_features).tolist()
    top_n = 10
    fig = plot_feature_importance_by_family(names, scores, top_n=top_n)
    ax = fig.axes[0]
    # barh produces one patch per bar
    assert len(ax.patches) == top_n


def test_importance_plot_family_colors_distinct():
    """Features from six different families get six different bar colours."""
    # One representative per family: lag, holiday, polynomial,
    # weather_window, cyclical/RBF and the weather/other catch-all.
    names = ["lag_1", "holiday_DE", "poly_hour_2", "window_mean_72",
             "sin_hour", "wind_speed"]
    scores = [100, 90, 80, 70, 60, 50]
    fig = plot_feature_importance_by_family(names, scores)
    ax = fig.axes[0]
    face_colors = {tuple(p.get_facecolor()) for p in ax.patches}
    # Asserting merely "> 1" would still pass if five families collapsed
    # onto one colour; every family must keep its own colour.
    assert len(face_colors) == len(names)


def test_importance_plot_top_n_larger_than_features():
    """top_n larger than available features still works (no error)."""
    names = ["lag_1", "lag_24"]
    scores = [50, 30]
    fig = plot_feature_importance_by_family(names, scores, top_n=20)
    assert isinstance(fig, Figure)
def _fit_tiny_lgbm(n_rows, columns, seed, signal_scale, noise_scale):
    """Fit a 10-tree LGBMRegressor on seeded noise; skip if deps are missing.

    The target is the first column times `signal_scale` plus Gaussian noise
    scaled by `noise_scale`. Returns the fitted estimator and its features.
    """
    pytest.importorskip("lightgbm")
    pytest.importorskip("shap")
    from lightgbm import LGBMRegressor

    rng = np.random.default_rng(seed)
    X = pd.DataFrame(
        rng.standard_normal((n_rows, len(columns))), columns=list(columns)
    )
    y = X[columns[0]] * signal_scale + rng.standard_normal(n_rows) * noise_scale
    est = LGBMRegressor(n_estimators=10, verbose=-1, random_state=0)
    est.fit(X, y)
    return est, X


def test_shap_summary_returns_figure():
    """A tiny fitted LGBMRegressor yields a Figure from plot_shap_summary."""
    est, X = _fit_tiny_lgbm(
        120, ["f0", "f1", "f2", "f3"], seed=7, signal_scale=2, noise_scale=0.1
    )

    fig = plot_shap_summary(est, X, max_samples=50)

    assert isinstance(fig, Figure)


def test_shap_summary_subsamples_max_samples():
    """With more rows than max_samples the call still returns a Figure.

    The SHAP call itself cannot be inspected here, so this only checks that
    the function completes and returns a Figure when the subsampling path
    (stride = max(1, len(X) // max_samples)) is taken with a stride above 1.
    """
    est, X = _fit_tiny_lgbm(
        500, ["a", "b", "c"], seed=9, signal_scale=1, noise_scale=0.2
    )

    fig = plot_shap_summary(est, X, max_samples=50)

    assert isinstance(fig, Figure)
# ---------------------------------------------------------------------------
# plot_forecast_vs_reference
# ---------------------------------------------------------------------------


def _make_forecast(n: int = 24, seed: int = 0) -> pd.Series:
    """Build an hourly UTC series of length n around 40 000 with seeded noise."""
    generator = np.random.default_rng(seed)
    hours = pd.date_range("2024-01-15", periods=n, freq="h", tz="UTC")
    values = 40_000 + generator.standard_normal(n) * 1000
    return pd.Series(values, index=hours)


def test_forecast_vs_reference_returns_figure_with_overlap():
    """Identical indices (full overlap) still produce a Figure."""
    fig = plot_forecast_vs_reference(
        _make_forecast(24),
        _make_forecast(24, seed=1),
        forecast_label="team_4",
        reference_label="ENTSO-E",
    )
    assert isinstance(fig, Figure)


def test_forecast_vs_reference_two_lines_when_overlap():
    """Forecast and reference are both drawn when they overlap."""
    fig = plot_forecast_vs_reference(
        _make_forecast(24), _make_forecast(24, seed=2)
    )
    drawn_lines = fig.axes[0].lines
    assert len(drawn_lines) == 2


def test_forecast_vs_reference_one_line_when_empty_overlap():
    """A reference on disjoint timestamps leaves only the forecast line."""
    forecast = _make_forecast(24)
    # Reference lives on a completely different date range (no overlap).
    other_hours = pd.date_range("2025-06-01", periods=24, freq="h", tz="UTC")
    generator = np.random.default_rng(3)
    reference = pd.Series(
        40_000 + generator.standard_normal(24) * 500, index=other_hours
    )

    fig = plot_forecast_vs_reference(forecast, reference)

    drawn_lines = fig.axes[0].lines
    assert len(drawn_lines) == 1
def test_forecast_vs_reference_no_exception_on_empty_overlap():
    """An empty reference series must not raise; a Figure is still returned."""
    reference = pd.Series(dtype=float)
    fig = plot_forecast_vs_reference(_make_forecast(24), reference)
    assert isinstance(fig, Figure)


def test_forecast_vs_reference_unit_scale_applied(caplog):
    """With unit_scale=1 the y-axis stays in the raw MW range."""
    import logging

    forecast = _make_forecast(24)
    reference = _make_forecast(24, seed=5)
    with caplog.at_level(logging.INFO, logger="spotforecast2.plots.diagnostics"):
        fig = plot_forecast_vs_reference(
            forecast, reference, unit_scale=1.0, unit="MW"
        )

    # Values near 40 000 put the upper limit far above 1000 when unscaled;
    # after a MW -> GW conversion it would be below 100.
    _, upper = fig.axes[0].get_ylim()
    assert upper > 1000


def test_forecast_vs_reference_mad_logged_in_display_unit(caplog):
    """The logged MAD is expressed in the display unit, not the raw unit.

    Forecast and reference differ by a constant 1000 MW, so with
    unit_scale=1e-3 the logged value must be 1.0 GW. Dividing by the scale
    instead of multiplying would log 1000000.
    """
    import logging
    import re

    hours = pd.date_range("2024-01-15", periods=24, freq="h", tz="UTC")
    forecast = pd.Series(np.full(24, 40_000.0), index=hours)
    # Exact constant offset of 1000 MW -> MAD = 1000 MW = 1.0 GW
    reference = pd.Series(np.full(24, 41_000.0), index=hours)

    with caplog.at_level(logging.INFO, logger="spotforecast2.plots.diagnostics"):
        plot_forecast_vs_reference(
            forecast,
            reference,
            forecast_label="fc",
            reference_label="ref",
            unit_scale=1e-3,
            unit="GW",
        )

    # Locate the MAD record and pull out the number in front of " GW".
    mad_messages = [rec for rec in caplog.records if "overlap hours" in rec.message]
    assert mad_messages, "Expected MAD log message not found in caplog"
    msg = mad_messages[0].message
    match = re.search(r"([\d.]+)\s+GW", msg)
    assert match, f"Could not parse GW value from log message: {msg!r}"
    logged_value = float(match.group(1))
    assert logged_value == pytest.approx(1.0, abs=0.05), (
        f"Logged MAD should be 1.0 GW, got {logged_value}. "
        f"Full message: {msg!r}"
    )