From 8565b8fec2273873e006496c4d0fefd690a1bea2 Mon Sep 17 00:00:00 2001 From: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com> Date: Thu, 11 Jun 2026 01:07:34 +0200 Subject: [PATCH] feat(multitask)!: seed SpotOptim warm start from warm_start_lags list SpotOptimStrategy reads the seed lag set directly from config.warm_start_lags (sf2-safe >= 22.0.0, list-valued; default DEFAULT_WARM_START_LAGS) instead of combining the old boolean flag with lags_consider. None or an empty list disables the warm start; lags_consider keeps only its passive roles. BREAKING CHANGE: configs with warm_start_lags=True no longer warm-start the search with lags_consider; set warm_start_lags to the seed lag list itself (or None to disable). Co-Authored-By: Claude Fable 5 --- .../model_selection/spotoptim_search.py | 2 +- src/spotforecast2/multitask/strategies.py | 23 ++++++++++--------- tests/test_multitask_strategies.py | 2 +- tests/test_spotoptim_warm_start.py | 11 ++++----- uv.lock | 2 +- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/spotforecast2/model_selection/spotoptim_search.py b/src/spotforecast2/model_selection/spotoptim_search.py index 9f462386..8f905295 100644 --- a/src/spotforecast2/model_selection/spotoptim_search.py +++ b/src/spotforecast2/model_selection/spotoptim_search.py @@ -877,7 +877,7 @@ def build_warm_start_x0( ``SpotOptimStrategy.prepare_forecaster``). forecaster: The pre-tuning forecaster; its ``estimator`` supplies the starting values for the numeric hyperparameter dimensions. - lags_seed: The lag configuration to seed (e.g. ``config.lags_consider``). + lags_seed: The lag configuration to seed (e.g. ``config.warm_start_lags``). Returns: A 1-D float array of length ``len(var_name)``, or ``None`` when the diff --git a/src/spotforecast2/multitask/strategies.py b/src/spotforecast2/multitask/strategies.py index 118f2606..f0352792 100644 --- a/src/spotforecast2/multitask/strategies.py +++ b/src/spotforecast2/multitask/strategies.py @@ -241,9 +241,10 @@ def prepare_forecaster( task: A `BaseTask` (or compatible) instance that supplies ``cv_ts``, ``config``, ``logger``, ``save_tuning_results``, and ``create_forecaster``. The config must expose - ``n_trials_spotoptim``, ``n_initial_spotoptim``, - ``random_state``, ``warm_start_lags``, and optionally - ``lags_consider`` and the TensorBoard knobs described + ``n_trials_spotoptim``, ``n_initial_spotoptim``, and + ``random_state``; optionally ``warm_start_lags`` (the seed + lag set, sf2-safe >= 22.0.0; ``None``/empty = cold start), + ``max_time_spotoptim``, and the TensorBoard knobs described above. target: Target column name; forwarded to ``task.create_forecaster`` and ``task.save_tuning_results``. @@ -288,7 +289,7 @@ def prepare_forecaster( n_trials_spotoptim=5, n_initial_spotoptim=3, random_state=0, - warm_start_lags=False, + warm_start_lags=None, ) task = types.SimpleNamespace( config=cfg, @@ -318,15 +319,15 @@ def prepare_forecaster( search_space = self.search_space or _default_spotoptim_search_space() cv = task.cv_ts(y_train) - # Warm start: inject ``lags_consider`` as a candidate lag set and seed - # the optimizer's first evaluation with it. Only dict search spaces - # with a ``"lags"`` list are eligible; anything else falls through to a - # normal cold-start run. + # Warm start: ``config.warm_start_lags`` (sf2-safe >= 22.0.0) is the + # seed lag set itself — it is injected as a search-space candidate and + # seeds the optimizer's first evaluation. ``None``/empty disables the + # warm start. Only dict search spaces with a ``"lags"`` list are + # eligible; anything else falls through to a normal cold-start run. kwargs_spotoptim: Dict[str, Any] = {} - lags_seed = getattr(task.config, "lags_consider", None) + lags_seed = getattr(task.config, "warm_start_lags", None) if ( - getattr(task.config, "warm_start_lags", False) - and lags_seed + lags_seed and isinstance(search_space, dict) and isinstance(search_space.get("lags"), list) ): diff --git a/tests/test_multitask_strategies.py b/tests/test_multitask_strategies.py index ce6354a3..1753d82a 100644 --- a/tests/test_multitask_strategies.py +++ b/tests/test_multitask_strategies.py @@ -75,7 +75,7 @@ def _make_fake_task(**config_extra): n_trials_spotoptim=2, n_initial_spotoptim=1, random_state=0, - warm_start_lags=False, + warm_start_lags=None, **config_extra, ) return types.SimpleNamespace( diff --git a/tests/test_spotoptim_warm_start.py b/tests/test_spotoptim_warm_start.py index b8f8b919..951309bf 100644 --- a/tests/test_spotoptim_warm_start.py +++ b/tests/test_spotoptim_warm_start.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2026 bartzbeielstein # SPDX-License-Identifier: AGPL-3.0-or-later -"""Tests for warm-starting the SpotOptim search with ``lags_consider``. +"""Tests for warm-starting the SpotOptim search with ``warm_start_lags``. Covers the ``build_warm_start_x0`` helper (full-dim seed point, clipping, graceful ``None`` cases) and an end-to-end check that the seeded lag @@ -112,8 +112,7 @@ def test_warm_start_x0_round_trips_to_seed(): class _MockConfig: - warm_start_lags = True - lags_consider = [1, 2, 24] + warm_start_lags = [1, 2, 24] random_state = 1 n_trials_spotoptim = 4 n_initial_spotoptim = 3 @@ -141,7 +140,7 @@ def create_forecaster(self, target): def test_strategy_injects_seed_and_forwards_x0(monkeypatch): - """The strategy injects lags_consider as a candidate and forwards x0.""" + """The strategy injects warm_start_lags as a candidate and forwards x0.""" from spotforecast2.multitask.strategies import SpotOptimStrategy captured = {} @@ -186,7 +185,7 @@ def fake_search(**kwargs): def test_strategy_no_x0_when_flag_disabled(monkeypatch): - """With warm_start_lags False, no x0 is passed and lags is untouched.""" + """With warm_start_lags None, no x0 is passed and lags is untouched.""" from spotforecast2.multitask.strategies import SpotOptimStrategy captured = {} @@ -207,7 +206,7 @@ def fake_search(**kwargs): ) task = _MockTask() - task.config.warm_start_lags = False + task.config.warm_start_lags = None y = pd.Series( np.sin(np.arange(400) * 2 * np.pi / 24), index=pd.date_range("2022-01-01", periods=400, freq="h"), diff --git a/uv.lock b/uv.lock index fd654820..f6c72523 100644 --- a/uv.lock +++ b/uv.lock @@ -3491,7 +3491,7 @@ wheels = [ [[package]] name = "spotforecast2" -version = "7.0.0" +version = "7.1.0" source = { editable = "." } dependencies = [ { name = "astral" },