Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/spotforecast2/model_selection/bayesian_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,46 @@ def bayesian_search_forecaster(
TypeError: If cv is not an instance of TimeSeriesFold or OneStepAheadFold.
ValueError: If metric list contains duplicate metric names.

Examples:
```{python}
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from spotforecast2_safe.forecaster.recursive import ForecasterRecursive
from spotforecast2_safe.splitter import TimeSeriesFold
from spotforecast2.model_selection.bayesian_search import bayesian_search_forecaster

rng = np.random.default_rng(0)
y = pd.Series(rng.standard_normal(40), name="y")

forecaster = ForecasterRecursive(estimator=Ridge(), lags=2)
cv = TimeSeriesFold(steps=2, initial_train_size=25, refit=False)

def search_space(trial):
return {
"estimator__alpha": trial.suggest_float("estimator__alpha", 0.01, 10.0),
}

results, best_trial = bayesian_search_forecaster(
forecaster=forecaster,
y=y,
cv=cv,
search_space=search_space,
metric="mean_squared_error",
n_trials=3,
random_state=0,
return_best=False,
verbose=False,
show_progress=False,
suppress_warnings=True,
)

print(results.shape)
print(results.columns.tolist())
assert results.shape[0] == 3
assert "mean_squared_error" in results.columns
assert "estimator__alpha" in results.columns
```
"""

if return_best and exog is not None and (len(exog) != len(y)):
Expand Down
39 changes: 39 additions & 0 deletions src/spotforecast2/model_selection/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,45 @@ def grid_search_forecaster(
) -> pd.DataFrame:
"""
Exhaustive grid search over parameter values for a Forecaster.

Examples:
```{python}
import warnings
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from spotforecast2_safe.forecaster.recursive import ForecasterRecursive
from spotforecast2_safe.splitter import TimeSeriesFold
from spotforecast2.model_selection.grid_search import grid_search_forecaster

rng = np.random.default_rng(0)
idx = pd.date_range("2020-01-01", periods=120, freq="h")
y = pd.Series(rng.normal(0, 1, 120), index=idx)

forecaster = ForecasterRecursive(estimator=Ridge(), lags=3)
cv = TimeSeriesFold(steps=3, initial_train_size=90, refit=False)
param_grid = {"alpha": [0.1, 1.0]}

with warnings.catch_warnings():
warnings.simplefilter("ignore")
results = grid_search_forecaster(
forecaster=forecaster,
y=y,
cv=cv,
param_grid=param_grid,
metric="mean_absolute_error",
lags_grid=[3, 5],
return_best=True,
n_jobs=1,
verbose=False,
show_progress=False,
suppress_warnings=True,
)

print(results[["lags_label", "params", "mean_absolute_error"]].head())
assert results.shape == (4, 5)
assert "mean_absolute_error" in results.columns
```
"""

param_grid = list(ParameterGrid(param_grid))
Expand Down
133 changes: 131 additions & 2 deletions src/spotforecast2/multitask/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,75 @@ class PlottingMixin:
plot_with_outliers: Display original vs. cleaned data with outlier markers.
_show_prediction_figure: Show an interactive per-target prediction figure.
_show_prediction_figure_agg: Show an interactive aggregated prediction figure.

Examples:
```{python}
import tempfile
import numpy as np
import pandas as pd
from spotforecast2_safe.configurator.config_multi import ConfigMulti
from spotforecast2.multitask import LazyTask
from spotforecast2.multitask.base import PlottingMixin

rng = np.random.default_rng(0)
idx = pd.date_range("2023-01-01", periods=24 * 14, freq="h", tz="UTC")
df = pd.DataFrame({"load": rng.normal(100, 10, len(idx))}, index=idx)
df.index.name = "DateTime"

with tempfile.TemporaryDirectory() as tmp:
cfg = ConfigMulti(
predict_size=6,
use_exogenous_features=False,
use_outlier_detection=False,
auto_save_models=False,
cache_home=tmp,
)
task = LazyTask(cfg, dataframe=df)

# LazyTask inherits from BaseTask(PlottingMixin, SafeBaseTask), so
# PlottingMixin.plot_with_outliers overrides the safe-base no-op stub.
print("PlottingMixin in MRO:", PlottingMixin in type(task).__mro__)
print(
"plot_with_outliers wired to PlottingMixin:",
type(task).plot_with_outliers is PlottingMixin.plot_with_outliers,
)
assert PlottingMixin in type(task).__mro__
assert type(task).plot_with_outliers is PlottingMixin.plot_with_outliers
```
"""

def plot_with_outliers(self) -> None:
"""Visualise original vs. cleaned data with outlier markers.

Raises:
RuntimeError: If ``detect_outliers`` has not been called.

Examples:
```{python}
import tempfile
import numpy as np
import pandas as pd
from spotforecast2_safe.configurator.config_multi import ConfigMulti
from spotforecast2.multitask import LazyTask

rng = np.random.default_rng(0)
idx = pd.date_range("2023-01-01", periods=24 * 14, freq="h", tz="UTC")
df = pd.DataFrame({"load": rng.normal(100, 10, len(idx))}, index=idx)
df.index.name = "DateTime"

with tempfile.TemporaryDirectory() as tmp:
cfg = ConfigMulti(
predict_size=6,
use_exogenous_features=False,
use_outlier_detection=False,
bounds=[(50, 150)],
auto_save_models=False,
cache_home=tmp,
)
task = LazyTask(cfg, dataframe=df)
task.prepare_data().detect_outliers()
task.plot_with_outliers()
```
"""
if self.df_pipeline_original is None: # type: ignore[attr-defined]
raise RuntimeError("Call detect_outliers() before plot_with_outliers().")
Expand Down Expand Up @@ -125,10 +187,43 @@ class BaseTask(PlottingMixin, SafeBaseTask):

Visualisation additions over the safe base:
plot_with_outliers: Renders original vs. cleaned data with outlier
markers via ``spotforecast2.plots.plotter.plot_with_outliers``.
_show_prediction_figure: Calls ``make_plot`` and shows the figure
markers via `spotforecast2.plots.plotter.plot_with_outliers`.
_show_prediction_figure: Calls `make_plot` and shows the figure
interactively.
_show_prediction_figure_agg: Same for the aggregated prediction.

Examples:
```{python}
import tempfile
import numpy as np
import pandas as pd
from spotforecast2_safe.configurator.config_multi import ConfigMulti
from spotforecast2.multitask.base import BaseTask, PlottingMixin

rng = np.random.default_rng(0)
idx = pd.date_range("2023-01-01", periods=24 * 14, freq="h", tz="UTC")
df = pd.DataFrame({"load": rng.normal(100, 10, len(idx))}, index=idx)
df.index.name = "DateTime"

with tempfile.TemporaryDirectory() as tmp:
cfg = ConfigMulti(
predict_size=6,
use_exogenous_features=False,
use_outlier_detection=False,
auto_save_models=False,
cache_home=tmp,
)
task = BaseTask(cfg, dataframe=df)
# Data-preparation pipeline (steps 1-3)
task.prepare_data().detect_outliers().impute()

print("Pipeline shape:", task.df_pipeline.shape)
print("Targets:", task.config.targets)
# PlottingMixin is in the MRO — visualisation hooks are live Plotly calls.
print("PlottingMixin in MRO:", PlottingMixin in type(task).__mro__)
assert task.df_pipeline.shape[1] == 1
assert PlottingMixin in type(task).__mro__
```
"""

# ``_show_prediction_figure`` and ``_show_prediction_figure_agg`` are
Expand Down Expand Up @@ -223,6 +318,40 @@ def run( # noqa: PLR0913

Raises:
NotImplementedError: Always, unless overridden by a subclass.

Examples:
`BaseTask.run` is a placeholder, not an abstract method in the `abc`
sense: `BaseTask` itself can be instantiated, but calling `run` on it
always raises `NotImplementedError`, which forces every concrete task
subclass to provide its own implementation. Use `LazyTask`, `OptunaTask`,
`SpotOptimTask`, `PredictTask`, or `CleanTask` for live pipelines.

```{python}
import tempfile
import numpy as np
import pandas as pd
from spotforecast2_safe.configurator.config_multi import ConfigMulti
from spotforecast2.multitask.base import BaseTask
from spotforecast2.multitask import LazyTask

rng = np.random.default_rng(0)
idx = pd.date_range("2023-01-01", periods=24 * 14, freq="h", tz="UTC")
df = pd.DataFrame({"load": rng.normal(100, 10, len(idx))}, index=idx)
df.index.name = "DateTime"

# BaseTask.run raises NotImplementedError — use a concrete subclass.
with tempfile.TemporaryDirectory() as tmp:
cfg = ConfigMulti(cache_home=tmp)
base = BaseTask(cfg)
try:
base.run()
except NotImplementedError as exc:
print("BaseTask.run() raised NotImplementedError (expected).")
print(str(exc)[:60])

# LazyTask overrides run() with lazy fitting logic.
print("LazyTask.run is overridden:", LazyTask.run is not BaseTask.run)
assert LazyTask.run is not BaseTask.run
```
"""
raise NotImplementedError(
f"{self.__class__.__name__} must implement run(). "
Expand Down
27 changes: 27 additions & 0 deletions src/spotforecast2/multitask/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,33 @@ def run(
Returns:
Aggregated prediction package. Per-target packages are stored
on ``self.results["defaults"]``.

Examples:
```{python}
from unittest.mock import MagicMock, patch
from spotforecast2.multitask import DefaultsTask

task = DefaultsTask(predict_size=24, auto_save_models=False)
task.config.targets = ["t1"]

sentinel = {"future_pred": MagicMock(name="predictions")}
with (
patch.object(task, "_ensure_pipeline_ready"),
patch.object(
task,
"_get_target_data",
return_value=(MagicMock(), MagicMock(), MagicMock()),
),
patch.object(task, "create_forecaster"),
patch.object(task, "_train_and_predict_target", return_value=sentinel),
patch.object(task, "_aggregate_and_show", return_value=sentinel),
):
result = task.run(show=False)

assert "future_pred" in result
print(f"task.TASK: {task.TASK!r}")
print(f"result keys: {list(result.keys())}")
```
"""
del kwargs # DefaultsTask has no tuning- or cache-related parameters
return execute_defaults(self, show=show)
48 changes: 48 additions & 0 deletions src/spotforecast2/multitask/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,54 @@ def run(
Returns:
Aggregated prediction package. Per-target packages are stored
on ``self.results["lazy"]``.

Examples:
```{python}
import tempfile
import numpy as np
import pandas as pd
from lightgbm import LGBMRegressor
from spotforecast2.multitask import LazyTask
from spotforecast2_safe.configurator.config_multi import ConfigMulti
from spotforecast2_safe.forecaster.recursive import ForecasterRecursive
from spotforecast2_safe.preprocessing import RollingFeatures

rng = np.random.default_rng(0)
n = 24 * 14 # two weeks of hourly data
idx = pd.date_range("2023-01-01", periods=n, freq="h", tz="UTC")
idx.name = "DateTime"
df = pd.DataFrame({"load": rng.normal(100, 10, n)}, index=idx)

def _fast_factory(config, *, weight_func=None, target=None):
return ForecasterRecursive(
estimator=LGBMRegressor(
n_estimators=10,
random_state=config.random_state,
verbose=-1,
),
lags=6,
window_features=RollingFeatures(stats=["mean"], window_sizes=6),
weight_func=weight_func,
)

with tempfile.TemporaryDirectory() as tmp:
cfg = ConfigMulti(
predict_size=6,
use_exogenous_features=False,
use_outlier_detection=False,
auto_save_models=False,
number_folds=2,
random_state=42,
forecaster_factory=_fast_factory,
cache_home=tmp,
)
task = LazyTask(cfg, dataframe=df)
task.prepare_data().detect_outliers().impute().build_exogenous_features()
result = task.run(show=False, use_tuned_params=False)

print(f"Future predictions: {len(result['future_pred'])} steps")
assert len(result["future_pred"]) == 6
```
"""
return execute_lazy(
self,
Expand Down
Loading
Loading