Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ dependencies = [
"ruff>=0.15.6",
"scikit-learn>=1.8.0",
"shap>=0.49.1",
"spotforecast2-safe>=21.0.0,<22",
# 21.2.0 added the max_time_spotoptim config field that SpotOptimStrategy
# forwards as SpotOptim's max_time.
"spotforecast2-safe>=21.2.0,<22",
# spotoptim 1.0 is sequential-only and lean: torch/tensorboard moved to its
# ``[torch]`` extra. sf2 forwards tensorboard_* kwargs into SpotOptim, so we
# pin the extra to keep the TensorBoard tuning dashboards working (they were
Expand Down
17 changes: 17 additions & 0 deletions src/spotforecast2/multitask/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ def prepare_forecaster(
task its own ``tensorboard_path``, when several tasks share one
log directory.

Time budget: when ``config.max_time_spotoptim`` is set (wall-clock
minutes, a first-class `ConfigMulti`/`ConfigEntsoe` field in
sf2-safe >= 21.2.0), it is forwarded as SpotOptim's ``max_time``;
the search stops after ``n_trials_spotoptim`` evaluations or that
time limit, whichever comes first. ``None`` (the default) means no
time limit.

Args:
task: A `BaseTask` (or compatible) instance that supplies ``cv_ts``,
``config``, ``logger``, ``save_tuning_results``, and
Expand Down Expand Up @@ -351,6 +358,16 @@ def prepare_forecaster(
kwargs_spotoptim.update(tb_kwargs)
task.logger.info(" SpotOptim TensorBoard: %s", tb_kwargs)

# Wall-clock budget: forward ``config.max_time_spotoptim`` (minutes,
# sf2-safe >= 21.2.0) as SpotOptim's ``max_time``. The search then
# stops at ``n_trials_spotoptim`` evaluations or the time limit,
# whichever comes first. ``None`` (the default) forwards nothing and
# leaves SpotOptim's own default (no limit) in charge.
max_time = getattr(task.config, "max_time_spotoptim", None)
if max_time is not None:
kwargs_spotoptim["max_time"] = float(max_time)
task.logger.info(" SpotOptim max_time: %.2f min", float(max_time))

tuning_results, _ = spotoptim_search_forecaster(
forecaster=forecaster,
y=y_train,
Expand Down
46 changes: 46 additions & 0 deletions tests/test_multitask_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,49 @@ def fake_search_forecaster(*args, **kwargs):

# No tensorboard attrs -> empty kwargs dict -> None passed through.
assert captured["kwargs_spotoptim"] is None


def test_spotoptim_strategy_forwards_max_time(monkeypatch):
"""config.max_time_spotoptim (minutes) must reach kwargs_spotoptim as max_time."""
import pandas as pd

import spotforecast2.model_selection as ms

captured = {}

def fake_search_forecaster(*args, **kwargs):
captured["kwargs_spotoptim"] = kwargs.get("kwargs_spotoptim")
results = pd.DataFrame({"params": [{"alpha": 1.0}], "lags": [[1, 2]]})
return results, object()

monkeypatch.setattr(ms, "spotoptim_search_forecaster", fake_search_forecaster)

task = _make_fake_task(max_time_spotoptim=2.5)
SpotOptimStrategy(search_space={"alpha": (0.1, 1.0)}).prepare_forecaster(
task, "A", _FakeForecaster(), y_train=None
)

assert captured["kwargs_spotoptim"]["max_time"] == 2.5


def test_spotoptim_strategy_none_max_time_not_forwarded(monkeypatch):
"""max_time_spotoptim=None (the config default) must forward nothing."""
import pandas as pd

import spotforecast2.model_selection as ms

captured = {}

def fake_search_forecaster(*args, **kwargs):
captured["kwargs_spotoptim"] = kwargs.get("kwargs_spotoptim")
results = pd.DataFrame({"params": [{"alpha": 1.0}], "lags": [[1, 2]]})
return results, object()

monkeypatch.setattr(ms, "spotoptim_search_forecaster", fake_search_forecaster)

task = _make_fake_task(max_time_spotoptim=None)
SpotOptimStrategy(search_space={"alpha": (0.1, 1.0)}).prepare_forecaster(
task, "A", _FakeForecaster(), y_train=None
)

assert captured["kwargs_spotoptim"] is None
10 changes: 5 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading