diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index daa490204..4e5bb2546 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -1467,6 +1467,31 @@ async def _experimental_fork_checkpoint(
 
         shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)
 
+        # Make the fork effective for already-created local services. The
+        # checkpoint copy alone updates disk, but Unsloth may already have a
+        # cached trainer and a running vLLM server pointed at the fresh step-0
+        # adapter.
+        service = await self._get_service(cast(TrainableModel, model))
+        if hasattr(service, "_state") and "_state" in service.__dict__:
+            del service.__dict__["_state"]
+            if verbose:
+                print("Invalidated service _state cache for forked checkpoint")
+        service._forked_checkpoint_dir = dest_checkpoint_dir  # type: ignore[attr-defined]
+
+        server_started = bool(getattr(service, "_vllm_process", None)) or bool(
+            getattr(service, "_server_task", None)
+        )
+        register_lora = getattr(service, "register_lora_for_step", None)
+        if server_started and callable(register_lora):
+            await register_lora(selected_step, dest_checkpoint_dir)
+            if verbose:
+                print(
+                    f"Registered forked checkpoint {model.name}@{selected_step} "
+                    "with running inference service"
+                )
+        elif hasattr(service, "_latest_step"):
+            service._latest_step = selected_step  # type: ignore[attr-defined]
+
         if verbose:
             print(
                 f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py
index 99ec77d76..56cb6792c 100644
--- a/src/art/pipeline_trainer/trainer.py
+++ b/src/art/pipeline_trainer/trainer.py
@@ -80,6 +80,13 @@ def __init__(
         adam_params: object | None = None,
         packed_sequence_length: int | None = None,
         max_steps: int | None = None,
+        # KL-penalized advantage adjustment
+        kl_penalty_coef: float = 0.0,
+        kl_penalty_reference_step: int | None = None,
+        kl_ref_adapter_path: str | None = None,
+        kl_window_size: int | None = None,
+        kl_window_base_step: int = 0,
+        kl_window_base_adapter_path: str | None = None,
         # Discard handling
         discard_queue_multiplier: int = 100,
         # Status output
@@ -90,6 +97,7 @@
         eval_every_n_steps: int = 20,
         eval_at_start: bool = True,
         save_checkpoint: bool = True,
+        save_checkpoint_artifact: bool = False,
         # Resumption
         resume: bool = True,
     ) -> None:
@@ -109,10 +117,20 @@
             raise ValueError("eval_every_n_steps must be >= 0")
         if max_steps is not None and max_steps < 0:
             raise ValueError("max_steps must be >= 0")
+        if kl_penalty_coef < 0:
+            raise ValueError("kl_penalty_coef must be >= 0")
+        if kl_penalty_reference_step is not None and kl_penalty_reference_step < 0:
+            raise ValueError("kl_penalty_reference_step must be >= 0")
+        if kl_window_size is not None and kl_window_size < 0:
+            raise ValueError("kl_window_size must be >= 0")
+        if kl_window_base_step < 0:
+            raise ValueError("kl_window_base_step must be >= 0")
         if log_interval_seconds <= 0:
             raise ValueError("log_interval_seconds must be > 0")
         if discard_queue_multiplier <= 0:
             raise ValueError("discard_queue_multiplier must be > 0")
+        if save_checkpoint_artifact and not save_checkpoint:
+            raise ValueError("save_checkpoint_artifact=True requires save_checkpoint=True")
         self.model = model
         self.backend = backend
         self.rollout_fn = rollout_fn
@@ -132,10 +150,17 @@
         self.adam_params = adam_params
         self.packed_sequence_length = packed_sequence_length
         self.max_steps = max_steps
+        self.kl_penalty_coef = kl_penalty_coef
+        self.kl_penalty_reference_step = kl_penalty_reference_step
+        self.kl_ref_adapter_path = kl_ref_adapter_path
+        self.kl_window_size = kl_window_size
+        self.kl_window_base_step = kl_window_base_step
+        self.kl_window_base_adapter_path = kl_window_base_adapter_path
         self._status_log_interval_seconds = log_interval_seconds
         self.eval_every_n_steps = eval_every_n_steps
         self.eval_at_start = eval_at_start
         self.save_checkpoint = save_checkpoint
+        self.save_checkpoint_artifact = save_checkpoint_artifact
         self.resume = resume
         self.discard_queue_multiplier = discard_queue_multiplier
         self._discard_queue: list[TrajectoryGroup] = []
@@ -374,24 +399,32 @@ async def _rollout_worker(self, worker_id: int) -> None:
                 token = self.model.activate_metrics_context("train")
                 rollout_started = time.monotonic()
                 try:
-                    group = await self.rollout_fn(self.model, scenario, self.config)
+                    result = await self.rollout_fn(self.model, scenario, self.config)
                 finally:
                     token.var.reset(token)
                 rollout_wall_s = time.monotonic() - rollout_started
-                if not isinstance(group, TrajectoryGroup):
+                groups = result if isinstance(result, list) else [result]
+                if not groups or not all(
+                    isinstance(group, TrajectoryGroup) for group in groups
+                ):
                     errored = True
                     continue
-                self._apply_scenario_metadata(group, scenario)
-                self._apply_policy_versions(
-                    group,
-                    initial_version=initial_version,
-                    final_version=self.state.policy_version,
-                )
-                if self.state.done:
-                    break
-                queue_wait_s = await self._put_output_group(group)
-                group.metadata[_ROLLOUT_WALL_TIME_KEY] = rollout_wall_s
-                group.metadata[_ACTOR_IDLE_TIME_KEY] = actor_idle_s + queue_wait_s
+                rollout_wall_per_group = rollout_wall_s / len(groups)
+                actor_idle_per_group = actor_idle_s / len(groups)
+                for group in groups:
+                    self._apply_scenario_metadata(group, scenario)
+                    self._apply_policy_versions(
+                        group,
+                        initial_version=initial_version,
+                        final_version=self.state.policy_version,
+                    )
+                    if self.state.done:
+                        break
+                    queue_wait_s = await self._put_output_group(group)
+                    group.metadata[_ROLLOUT_WALL_TIME_KEY] = rollout_wall_per_group
+                    group.metadata[_ACTOR_IDLE_TIME_KEY] = (
+                        actor_idle_per_group + queue_wait_s
+                    )
             except asyncio.CancelledError:
                 raise
             except Exception as exc:
@@ -464,11 +497,24 @@ async def _training_stage(self) -> None:
                 }
                 if self.packed_sequence_length is not None:
                     train_kwargs["packed_sequence_length"] = self.packed_sequence_length
+                train_kwargs.update(
+                    self._backend_kl_train_kwargs(current_step=current_step)
+                )
                 result = await self.backend.train(
                     self.model,
                     batch,
                     **train_kwargs,
                 )
+                checkpoint_path = getattr(result, "checkpoint_path", None)
+                if (
+                    should_checkpoint
+                    and self.save_checkpoint_artifact
+                    and checkpoint_path is not None
+                ):
+                    self._save_checkpoint_artifact(
+                        checkpoint_path=checkpoint_path,
+                        step=result.step,
+                    )
             except Exception:
                 self._status.note_training_end()
                 raise
@@ -810,6 +856,53 @@ def _should_eval_step(self, step: int) -> bool:
             return False
         return (step - self.state.last_eval_step) >= self.eval_every_n_steps
 
+    def _backend_kl_train_kwargs(self, *, current_step: int) -> dict[str, Any]:
+        if self.kl_penalty_coef <= 0:
+            return {}
+
+        kwargs: dict[str, Any] = {"kl_penalty_coef": self.kl_penalty_coef}
+        if self.kl_ref_adapter_path is not None:
+            kwargs["kl_ref_adapter_path"] = self.kl_ref_adapter_path
+            return kwargs
+
+        if self.kl_penalty_reference_step is not None:
+            kwargs["kl_penalty_reference_step"] = self.kl_penalty_reference_step
+            return kwargs
+
+        if self.kl_window_size is None:
+            return kwargs
+
+        if self.kl_window_size == 0:
+            if self.kl_window_base_adapter_path is not None:
+                kwargs["kl_ref_adapter_path"] = self.kl_window_base_adapter_path
+            return kwargs
+
+        target_step = current_step - self.kl_window_size
+        if target_step <= self.kl_window_base_step:
+            reference_step = self.kl_window_base_step
+        elif self.eval_every_n_steps <= 0:
+            reference_step = target_step
+        else:
+            window_steps = (target_step - self.kl_window_base_step) // (
+                self.eval_every_n_steps
+            )
+            reference_step = (
+                self.kl_window_base_step + window_steps * self.eval_every_n_steps
+            )
+        kwargs["kl_penalty_reference_step"] = reference_step
+        return kwargs
+
+    def _save_checkpoint_artifact(self, *, checkpoint_path: str, step: int) -> None:
+        from art.utils.deployment import WandbDeploymentConfig, deploy_wandb
+
+        deploy_wandb(
+            model=self.model,
+            checkpoint_path=checkpoint_path,
+            step=step,
+            config=WandbDeploymentConfig(provenance=["local-rl"]),
+            verbose=True,
+        )
+
     def _read_pipeline_state(self) -> dict[str, Any]:
         state = self.model.read_state() or {}
         return state.get(PIPELINE_STATE_KEY, {})
@@ -829,6 +922,9 @@ def _is_scalar_metadata(value: object) -> bool:
 
     async def _put_output_group(self, group: TrajectoryGroup) -> float:
         assert self._output_queue is not None
+        if group.metadata and group.metadata.get("skip_training"):
+            self._status.note_zero_variance_discarded(1)
+            return 0.0
         queue_wait_started = time.monotonic()
         while not self.state.done:
             try:
diff --git a/src/art/pipeline_trainer/types.py b/src/art/pipeline_trainer/types.py
index 4b04891e2..4dd8d2263 100644
--- a/src/art/pipeline_trainer/types.py
+++ b/src/art/pipeline_trainer/types.py
@@ -12,7 +12,8 @@
 
 
 RolloutFn = Callable[
-    [art.TrainableModel, ScenarioT, ConfigT], Awaitable[TrajectoryGroup]
+    [art.TrainableModel, ScenarioT, ConfigT],
+    Awaitable[TrajectoryGroup | list[TrajectoryGroup]],
 ]
 
 SingleRolloutFn = Callable[
diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py
index 2a5a60abf..22c0f6c94 100644
--- a/src/art/unsloth/service.py
+++ b/src/art/unsloth/service.py
@@ -108,6 +108,7 @@ class UnslothService:
     output_dir: str
     _is_sleeping: bool = False
    _latest_step: int = 0
+    _forked_checkpoint_dir: str | None = None
     _lora_id_counter: int = 1  # Start from 1 since 0 is reserved
     # Dedicated mode subprocess state
     _vllm_process: subprocess.Popen | None = field(default=None, repr=False)  # type: ignore[type-arg]
@@ -571,6 +572,14 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
         self._latest_step = step
         await llm.resume_generation()
 
+    async def _load_forked_checkpoint_if_needed(self) -> None:
+        forked_dir = self._forked_checkpoint_dir
+        if forked_dir is None:
+            return
+
+        self._forked_checkpoint_dir = None
+        await self._state.load_lora_adapter(forked_dir)
+
     async def train(
         self,
         disk_packed_tensors: DiskPackedTensors,
@@ -598,6 +607,8 @@ async def _train_dedicated(
         verbose: bool = False,
     ) -> AsyncIterator[dict[str, float]]:
         """Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU."""
+        await self._load_forked_checkpoint_if_needed()
+
         async for result in run_unsloth_rl_training(
             self._state,
             disk_packed_tensors=disk_packed_tensors,
@@ -663,6 +674,8 @@
         # Reload training model to GPU (after vLLM is asleep)
         self._state.reload_to_gpu()
 
+        await self._load_forked_checkpoint_if_needed()
+
         async for result in run_unsloth_rl_training(
             self._state,
             disk_packed_tensors=disk_packed_tensors,
diff --git a/src/art/utils/deployment/wandb.py b/src/art/utils/deployment/wandb.py
index 9ddf778e8..49202a41a 100644
--- a/src/art/utils/deployment/wandb.py
+++ b/src/art/utils/deployment/wandb.py
@@ -32,6 +32,18 @@ class WandbDeploymentConfig(DeploymentConfig):
     "Qwen/Qwen2.5-14B-Instruct",
 ]
 
+WANDB_BASE_MODEL_ALIASES = {
+    "unsloth/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "unsloth/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+    "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+}
+
+
+def get_wandb_base_model(base_model: str) -> str:
+    """Return the W&B inference base model id for compatible aliases."""
+    return WANDB_BASE_MODEL_ALIASES.get(base_model, base_model)
+
 
 def deploy_wandb(
     model: "TrainableModel",
@@ -54,7 +66,8 @@
     """
    import wandb
 
-    if model.base_model not in WANDB_SUPPORTED_BASE_MODELS:
+    wandb_base_model = get_wandb_base_model(model.base_model)
+    if wandb_base_model not in WANDB_SUPPORTED_BASE_MODELS:
         raise UnsupportedBaseModelDeploymentError(
             message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by W&B. Supported models: {WANDB_SUPPORTED_BASE_MODELS}"
         )
@@ -77,7 +90,9 @@
         settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]),
     )
     try:
-        metadata: dict[str, object] = {"wandb.base_model": model.base_model}
+        metadata: dict[str, object] = {"wandb.base_model": wandb_base_model}
+        if wandb_base_model != model.base_model:
+            metadata["source_base_model"] = model.base_model
         if config is not None:
             metadata["wandb.provenance"] = config.provenance
         artifact = wandb.Artifact(
diff --git a/tests/unit/test_pipeline_trainer_batching.py b/tests/unit/test_pipeline_trainer_batching.py
index 0ab412e8f..733ed86da 100644
--- a/tests/unit/test_pipeline_trainer_batching.py
+++ b/tests/unit/test_pipeline_trainer_batching.py
@@ -1,6 +1,7 @@
 import asyncio
 from pathlib import Path
-from unittest.mock import MagicMock
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
@@ -8,6 +9,11 @@
 from art.pipeline_trainer.trainer import PipelineTrainer
 
 
+@pytest.fixture(autouse=True)
+def _skip_backend_validation(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(PipelineTrainer, "_validate_backend_support", lambda _self: None)
+
+
 def _make_group() -> TrajectoryGroup:
     return TrajectoryGroup(
         [
@@ -24,6 +30,22 @@
     )
 
 
+def _make_group_with_rewards(rewards: list[float]) -> TrajectoryGroup:
+    return TrajectoryGroup(
+        [
+            Trajectory(
+                reward=reward,
+                initial_policy_version=0,
+                messages_and_choices=[
+                    {"role": "user", "content": f"prompt-{idx}"},
+                    {"role": "assistant", "content": f"answer-{idx}"},
+                ],
+            )
+            for idx, reward in enumerate(rewards)
+        ]
+    )
+
+
 @pytest.mark.asyncio
 async def test_collect_batch_respects_max_batch_size(tmp_path: Path) -> None:
     model = TrainableModel(
@@ -65,3 +87,148 @@
     assert batch == [third]
     assert discarded == 0
     assert saw_sentinel
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_forwards_kl_window_reference_step(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-kl-window",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    backend = MagicMock()
+    backend.train = AsyncMock(return_value=SimpleNamespace(step=801, metrics={}))
+
+    trainer = PipelineTrainer(
+        model=model,
+        backend=backend,  # type: ignore[arg-type]
+        rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0),
+        scenarios=[],
+        config={},
+        num_rollout_workers=1,
+        min_batch_size=1,
+        max_batch_size=1,
+        max_steps_off_policy=1000,
+        max_steps=1,
+        eval_every_n_steps=10,
+        kl_penalty_coef=1.0,
+        kl_window_size=50,
+        kl_window_base_step=686,
+    )
+    trainer.state.next_training_step = 800
+    trainer._output_queue = asyncio.Queue()
+    await trainer._output_queue.put(_make_group())
+    await trainer._output_queue.put(None)
+
+    await trainer._training_stage()
+
+    assert backend.train.await_args.kwargs["kl_penalty_coef"] == 1.0
+    assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 746
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_kl_window_zero_uses_base_adapter_path(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-kl-window-zero",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    backend = MagicMock()
+    backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={}))
+    adapter_path = str(tmp_path / "initial-checkpoint")
+
+    trainer = PipelineTrainer(
+        model=model,
+        backend=backend,  # type: ignore[arg-type]
+        rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0),
+        scenarios=[],
+        config={},
+        num_rollout_workers=1,
+        min_batch_size=1,
+        max_batch_size=1,
+        max_steps=1,
+        kl_penalty_coef=0.5,
+        kl_window_size=0,
+        kl_window_base_adapter_path=adapter_path,
+    )
+    trainer._output_queue = asyncio.Queue()
+    await trainer._output_queue.put(_make_group())
+    await trainer._output_queue.put(None)
+
+    await trainer._training_stage()
+
+    assert backend.train.await_args.kwargs["kl_penalty_coef"] == 0.5
+    assert backend.train.await_args.kwargs["kl_ref_adapter_path"] == adapter_path
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_rollout_worker_accepts_multiple_groups(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-multi-group-rollout",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+
+    group_a = _make_group_with_rewards([0.0, 1.0])
+    group_b = _make_group_with_rewards([0.25, 0.75])
+
+    async def rollout_fn(*_args: object) -> list[TrajectoryGroup]:
+        return [group_a, group_b]
+
+    trainer = PipelineTrainer(
+        model=model,
+        backend=MagicMock(),  # type: ignore[arg-type]
+        rollout_fn=rollout_fn,
+        scenarios=[{"id": "scenario-1"}],
+        config={},
+        num_rollout_workers=1,
+        min_batch_size=1,
+        max_steps=1,
+    )
+    trainer._output_queue = asyncio.Queue()
+
+    await trainer._rollout_worker(worker_id=0)
+
+    assert await trainer._output_queue.get() is group_a
+    assert await trainer._output_queue.get() is group_b
+    assert group_a.metadata["_art_rollout_wall_s"] >= 0
+    assert group_b.metadata["_art_actor_idle_s"] >= 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_skips_groups_marked_skip_training(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-skip-training-group",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    trainer = PipelineTrainer(
+        model=model,
+        backend=MagicMock(),  # type: ignore[arg-type]
+        rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0),
+        scenarios=[],
+        config={},
+        num_rollout_workers=1,
+        min_batch_size=1,
+        max_steps=1,
+    )
+    trainer._output_queue = asyncio.Queue()
+    group = _make_group_with_rewards([0.0, 0.0])
+    group.metadata["skip_training"] = True
+
+    queue_wait_s = await trainer._put_output_group(group)
+
+    assert queue_wait_s == 0.0
+    assert trainer._output_queue.empty()
diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py
index 90e2c59d7..f3696d84e 100644
--- a/tests/unit/test_pipeline_trainer_local_backend.py
+++ b/tests/unit/test_pipeline_trainer_local_backend.py
@@ -15,6 +15,7 @@
 from art.megatron.train import load_adapter_into_model
 from art.pipeline_trainer.trainer import PipelineTrainer
 from art.preprocessing.tokenize import TokenizedResult
+from art.utils.deployment.wandb import get_wandb_base_model
 from art.utils.output_dirs import get_model_dir
 
 
@@ -159,6 +160,80 @@ async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend(
     }
 
 
+@pytest.mark.asyncio
+async def test_pipeline_trainer_saves_checkpoint_artifact_on_eval_step(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-save-checkpoint-artifact",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    checkpoint_path = str(tmp_path / "checkpoint-1")
+    backend = MagicMock()
+    backend.train = AsyncMock(
+        return_value=SimpleNamespace(
+            step=1,
+            metrics={},
+            checkpoint_path=checkpoint_path,
+        )
+    )
+
+    trainer = _make_trainer(
+        model=model,
+        backend=backend,
+        eval_fn=AsyncMock(return_value=[]),
+        eval_every_n_steps=1,
+        save_checkpoint_artifact=True,
+    )
+    trainer._save_checkpoint_artifact = MagicMock()  # type: ignore[method-assign]
+    trainer._output_queue = asyncio.Queue()
+    await trainer._output_queue.put(_make_group([0.0, 1.0]))
+    await trainer._output_queue.put(None)
+
+    await trainer._training_stage()
+
+    assert backend.train.await_args.kwargs["save_checkpoint"] is True
+    trainer._save_checkpoint_artifact.assert_called_once_with(  # type: ignore[attr-defined]
+        checkpoint_path=checkpoint_path,
+        step=1,
+    )
+
+
+def test_pipeline_trainer_checkpoint_artifact_requires_checkpoint(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-save-checkpoint-artifact-validation",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    backend = MagicMock()
+
+    with pytest.raises(
+        ValueError, match="save_checkpoint_artifact=True requires save_checkpoint=True"
+    ):
+        _make_trainer(
+            model=model,
+            backend=backend,
+            save_checkpoint=False,
+            save_checkpoint_artifact=True,
+        )
+
+
+def test_wandb_base_model_aliases_for_unsloth_llama() -> None:
+    assert (
+        get_wandb_base_model("unsloth/Meta-Llama-3.1-8B-Instruct")
+        == "meta-llama/Llama-3.1-8B-Instruct"
+    )
+    assert (
+        get_wandb_base_model("unsloth/Meta-Llama-3.1-70B-Instruct")
+        == "meta-llama/Llama-3.1-70B-Instruct"
+    )
+
+
 @pytest.mark.asyncio
 async def test_local_backend_train_translates_loss_fn(tmp_path: Path) -> None:
     model = TrainableModel(