From 6bff756ed3ebf063cd2f8b45904ce79668f2e64f Mon Sep 17 00:00:00 2001 From: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:17:48 +0200 Subject: [PATCH 1/2] feat(multitask): forward max_time_spotoptim to SpotOptim ConfigMulti/ConfigEntsoe (sf2-safe >= 21.1.0) carry a new max_time_spotoptim field: a wall-clock budget for the surrogate search in minutes. SpotOptimStrategy now forwards it as SpotOptim's max_time, so the search stops at n_trials_spotoptim evaluations or the time limit, whichever comes first. Unset/None forwards nothing (read via getattr, so older sf2-safe configs keep working). Co-Authored-By: Claude Fable 5 --- src/spotforecast2/multitask/strategies.py | 17 +++++++++ tests/test_multitask_strategies.py | 46 +++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/src/spotforecast2/multitask/strategies.py b/src/spotforecast2/multitask/strategies.py index 3b042565..118f2606 100644 --- a/src/spotforecast2/multitask/strategies.py +++ b/src/spotforecast2/multitask/strategies.py @@ -230,6 +230,13 @@ def prepare_forecaster( task its own ``tensorboard_path``, when several tasks share one log directory. + Time budget: when ``config.max_time_spotoptim`` is set (wall-clock + minutes, a first-class `ConfigMulti`/`ConfigEntsoe` field in + sf2-safe >= 21.2.0), it is forwarded as SpotOptim's ``max_time``; + the search stops after ``n_trials_spotoptim`` evaluations or that + time limit, whichever comes first. ``None`` (the default) means no + time limit. + Args: task: A `BaseTask` (or compatible) instance that supplies ``cv_ts``, ``config``, ``logger``, ``save_tuning_results``, and @@ -351,6 +358,16 @@ def prepare_forecaster( kwargs_spotoptim.update(tb_kwargs) task.logger.info(" SpotOptim TensorBoard: %s", tb_kwargs) + # Wall-clock budget: forward ``config.max_time_spotoptim`` (minutes, + # sf2-safe >= 21.2.0) as SpotOptim's ``max_time``. The search then + # stops at ``n_trials_spotoptim`` evaluations or the time limit, + # whichever comes first. ``None`` (the default) forwards nothing and + # leaves SpotOptim's own default (no limit) in charge. + max_time = getattr(task.config, "max_time_spotoptim", None) + if max_time is not None: + kwargs_spotoptim["max_time"] = float(max_time) + task.logger.info(" SpotOptim max_time: %.2f min", float(max_time)) + tuning_results, _ = spotoptim_search_forecaster( forecaster=forecaster, y=y_train, diff --git a/tests/test_multitask_strategies.py b/tests/test_multitask_strategies.py index c4e4e15e..ce6354a3 100644 --- a/tests/test_multitask_strategies.py +++ b/tests/test_multitask_strategies.py @@ -204,3 +204,49 @@ def fake_search_forecaster(*args, **kwargs): # No tensorboard attrs -> empty kwargs dict -> None passed through. assert captured["kwargs_spotoptim"] is None + + +def test_spotoptim_strategy_forwards_max_time(monkeypatch): + """config.max_time_spotoptim (minutes) must reach kwargs_spotoptim as max_time.""" + import pandas as pd + + import spotforecast2.model_selection as ms + + captured = {} + + def fake_search_forecaster(*args, **kwargs): + captured["kwargs_spotoptim"] = kwargs.get("kwargs_spotoptim") + results = pd.DataFrame({"params": [{"alpha": 1.0}], "lags": [[1, 2]]}) + return results, object() + + monkeypatch.setattr(ms, "spotoptim_search_forecaster", fake_search_forecaster) + + task = _make_fake_task(max_time_spotoptim=2.5) + SpotOptimStrategy(search_space={"alpha": (0.1, 1.0)}).prepare_forecaster( + task, "A", _FakeForecaster(), y_train=None + ) + + assert captured["kwargs_spotoptim"]["max_time"] == 2.5 + + +def test_spotoptim_strategy_none_max_time_not_forwarded(monkeypatch): + """max_time_spotoptim=None (the config default) must forward nothing.""" + import pandas as pd + + import spotforecast2.model_selection as ms + + captured = {} + + def fake_search_forecaster(*args, **kwargs): + captured["kwargs_spotoptim"] = kwargs.get("kwargs_spotoptim") + results = pd.DataFrame({"params": [{"alpha": 1.0}], "lags": [[1, 2]]}) + return results, object() + + monkeypatch.setattr(ms, "spotoptim_search_forecaster", fake_search_forecaster) + + task = _make_fake_task(max_time_spotoptim=None) + SpotOptimStrategy(search_space={"alpha": (0.1, 1.0)}).prepare_forecaster( + task, "A", _FakeForecaster(), y_train=None + ) + + assert captured["kwargs_spotoptim"] is None From 1f52d04b1e0e09fe1977b3f49a875454a93a08c2 Mon Sep 17 00:00:00 2001 From: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:31:44 +0200 Subject: [PATCH 2/2] feat(deps): require spotforecast2-safe >=21.2.0 21.2.0 carries the max_time_spotoptim config field consumed by SpotOptimStrategy. Co-Authored-By: Claude Fable 5 --- pyproject.toml | 4 +++- uv.lock | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7762074d..b03ec566 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,9 @@ dependencies = [ "ruff>=0.15.6", "scikit-learn>=1.8.0", "shap>=0.49.1", - "spotforecast2-safe>=21.0.0,<22", + # 21.2.0 added the max_time_spotoptim config field that SpotOptimStrategy + # forwards as SpotOptim's max_time. + "spotforecast2-safe>=21.2.0,<22", # spotoptim 1.0 is sequential-only and lean: torch/tensorboard moved to its # ``[torch]`` extra. sf2 forwards tensorboard_* kwargs into SpotOptim, so we # pin the extra to keep the TensorBoard tuning dashboards working (they were diff --git a/uv.lock b/uv.lock index dae7bada..fd654820 100644 --- a/uv.lock +++ b/uv.lock @@ -3491,7 +3491,7 @@ wheels = [ [[package]] name = "spotforecast2" -version = "6.1.0" +version = "7.0.0" source = { editable = "." } dependencies = [ { name = "astral" }, @@ -3571,7 +3571,7 @@ requires-dist = [ { name = "safety", marker = "extra == 'dev'", specifier = ">=3.0.0" }, { name = "scikit-learn", specifier = ">=1.8.0" }, { name = "shap", specifier = ">=0.49.1" }, - { name = "spotforecast2-safe", specifier = ">=21.0.0,<22" }, + { name = "spotforecast2-safe", specifier = ">=21.2.0,<22" }, { name = "spotoptim", extras = ["torch"], specifier = ">=1.0.0,<2" }, { name = "tqdm", specifier = ">=4.67.2" }, { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.29" }, @@ -3590,7 +3590,7 @@ dev = [ [[package]] name = "spotforecast2-safe" -version = "21.0.0" +version = "21.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astral" }, @@ -3607,9 +3607,9 @@ dependencies = [ { name = "statsmodels" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/55/54aed0c9cdbfa8126f7241114bbb26e35c1b712683379573243978bf69bd/spotforecast2_safe-21.0.0.tar.gz", hash = "sha256:642b61b3f08b52e12cd5a24f84878fa761201455b0e8fc2907dabb9ed9022afa", size = 20624222, upload-time = "2026-06-09T19:36:08.747Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/48/e12f37462beae893f901751648f8b830475cd58d19d8e8e9703236cd25c8/spotforecast2_safe-21.2.0.tar.gz", hash = "sha256:1b181bc157a0765b15a5348831902465d356b21cba3b6ad36b57d54f6a51b07c", size = 20630614, upload-time = "2026-06-10T16:28:37.526Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/71/ca/8ad0ce36cb44e539e77ea22701e4e8c4a259e45bd851cf41cd406e6cbfa8/spotforecast2_safe-21.0.0-py3-none-any.whl", hash = "sha256:ade1a635333d84097ed59d993dd7cfde5551ff350b125b65fb48699ac3a64cd3", size = 20688561, upload-time = "2026-06-09T19:36:06.457Z" }, + { url = "https://files.pythonhosted.org/packages/f8/73/0537ab7ce84308bfd49b04b365fde4a58ef2f8b3d40b415b8963777b71c7/spotforecast2_safe-21.2.0-py3-none-any.whl", hash = "sha256:1a80b4629062a0600b108416e93a07f75e1d31cf7bec0bcd02cb77089826e11c", size = 20696409, upload-time = "2026-06-10T16:28:34.959Z" }, ] [[package]]