Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spotforecast2/model_selection/spotoptim_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def build_warm_start_x0(
``SpotOptimStrategy.prepare_forecaster``).
forecaster: The pre-tuning forecaster; its ``estimator`` supplies the
starting values for the numeric hyperparameter dimensions.
lags_seed: The lag configuration to seed (e.g. ``config.lags_consider``).
lags_seed: The lag configuration to seed (e.g. ``config.warm_start_lags``).

Returns:
A 1-D float array of length ``len(var_name)``, or ``None`` when the
Expand Down
23 changes: 12 additions & 11 deletions src/spotforecast2/multitask/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,10 @@ def prepare_forecaster(
task: A `BaseTask` (or compatible) instance that supplies ``cv_ts``,
``config``, ``logger``, ``save_tuning_results``, and
``create_forecaster``. The config must expose
``n_trials_spotoptim``, ``n_initial_spotoptim``,
``random_state``, ``warm_start_lags``, and optionally
``lags_consider`` and the TensorBoard knobs described
``n_trials_spotoptim``, ``n_initial_spotoptim``, and
``random_state``; optionally ``warm_start_lags`` (the seed
lag set, sf2-safe >= 22.0.0; ``None``/empty = cold start),
``max_time_spotoptim``, and the TensorBoard knobs described
above.
target: Target column name; forwarded to ``task.create_forecaster``
and ``task.save_tuning_results``.
Expand Down Expand Up @@ -288,7 +289,7 @@ def prepare_forecaster(
n_trials_spotoptim=5,
n_initial_spotoptim=3,
random_state=0,
warm_start_lags=False,
warm_start_lags=None,
)
task = types.SimpleNamespace(
config=cfg,
Expand Down Expand Up @@ -318,15 +319,15 @@ def prepare_forecaster(
search_space = self.search_space or _default_spotoptim_search_space()
cv = task.cv_ts(y_train)

# Warm start: inject ``lags_consider`` as a candidate lag set and seed
# the optimizer's first evaluation with it. Only dict search spaces
# with a ``"lags"`` list are eligible; anything else falls through to a
# normal cold-start run.
# Warm start: ``config.warm_start_lags`` (sf2-safe >= 22.0.0) is the
# seed lag set itself — it is injected as a search-space candidate and
# seeds the optimizer's first evaluation. ``None``/empty disables the
# warm start. Only dict search spaces with a ``"lags"`` list are
# eligible; anything else falls through to a normal cold-start run.
kwargs_spotoptim: Dict[str, Any] = {}
lags_seed = getattr(task.config, "lags_consider", None)
lags_seed = getattr(task.config, "warm_start_lags", None)
if (
getattr(task.config, "warm_start_lags", False)
and lags_seed
lags_seed
and isinstance(search_space, dict)
and isinstance(search_space.get("lags"), list)
):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multitask_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _make_fake_task(**config_extra):
n_trials_spotoptim=2,
n_initial_spotoptim=1,
random_state=0,
warm_start_lags=False,
warm_start_lags=None,
**config_extra,
)
return types.SimpleNamespace(
Expand Down
11 changes: 5 additions & 6 deletions tests/test_spotoptim_warm_start.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2026 bartzbeielstein
# SPDX-License-Identifier: AGPL-3.0-or-later

"""Tests for warm-starting the SpotOptim search with ``lags_consider``.
"""Tests for warm-starting the SpotOptim search with ``warm_start_lags``.

Covers the ``build_warm_start_x0`` helper (full-dim seed point, clipping,
graceful ``None`` cases) and an end-to-end check that the seeded lag
Expand Down Expand Up @@ -112,8 +112,7 @@ def test_warm_start_x0_round_trips_to_seed():


class _MockConfig:
warm_start_lags = True
lags_consider = [1, 2, 24]
warm_start_lags = [1, 2, 24]
random_state = 1
n_trials_spotoptim = 4
n_initial_spotoptim = 3
Expand Down Expand Up @@ -141,7 +140,7 @@ def create_forecaster(self, target):


def test_strategy_injects_seed_and_forwards_x0(monkeypatch):
"""The strategy injects lags_consider as a candidate and forwards x0."""
"""The strategy injects warm_start_lags as a candidate and forwards x0."""
from spotforecast2.multitask.strategies import SpotOptimStrategy

captured = {}
Expand Down Expand Up @@ -186,7 +185,7 @@ def fake_search(**kwargs):


def test_strategy_no_x0_when_flag_disabled(monkeypatch):
"""With warm_start_lags False, no x0 is passed and lags is untouched."""
"""With warm_start_lags None, no x0 is passed and lags is untouched."""
from spotforecast2.multitask.strategies import SpotOptimStrategy

captured = {}
Expand All @@ -207,7 +206,7 @@ def fake_search(**kwargs):
)

task = _MockTask()
task.config.warm_start_lags = False
task.config.warm_start_lags = None
y = pd.Series(
np.sin(np.arange(400) * 2 * np.pi / 24),
index=pd.date_range("2022-01-01", periods=400, freq="h"),
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading