From 02fd39cb6404a1f065a6f4652f90b365003a36e0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 26 May 2026 16:58:52 -0400 Subject: [PATCH 01/41] Add tool to evaluate layer-wise numerical-error propagation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new `tools/evaluate_precision.py` (`RunnableConfig`) drives a fp32 reference run plus one one-iteration trainer run per named variant from a Fast-LLM training YAML, then extracts per-layer forward activations and input gradients from the saved tensor logs and reports per-tensor RMS and max diffs (absolute and scaled). Variants are flat dicts of dotted-path overrides, the same syntax as Fast-LLM CLI key=value args, so they can sweep arbitrary configuration knobs (dtype, attention implementation, optimizer dtype, etc.) — not just compute_dtype. Also moves `compare_tensor_logs.py` into the `fast_llm` package so it is importable from `tools/` (the test tree isn't on sys.path for script entry points), and factors a `_compute_diff` helper out of `CompareConfig.compare_tensors` so the tool can extract numbers for every tensor rather than only those that breach a tolerance. Existing test callers are unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../config_utils}/compare_tensor_logs.py | 55 +++-- tests/models/test_checkpoint.py | 2 +- tests/models/test_match_megatron.py | 2 +- tests/utils/distributed_configs.py | 2 +- tools/evaluate_precision.py | 201 ++++++++++++++++++ 5 files changed, 243 insertions(+), 19 deletions(-) rename {tests/utils => fast_llm/engine/config_utils}/compare_tensor_logs.py (79%) create mode 100644 tools/evaluate_precision.py diff --git a/tests/utils/compare_tensor_logs.py b/fast_llm/engine/config_utils/compare_tensor_logs.py similarity index 79% rename from tests/utils/compare_tensor_logs.py rename to fast_llm/engine/config_utils/compare_tensor_logs.py index f02d62c79..080510036 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/fast_llm/engine/config_utils/compare_tensor_logs.py @@ -87,6 +87,30 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): # Avoid set to preserve ordering. return [key for key in dict_test if key in dict_ref] + def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict | None: + # Returns per-tensor error metrics, or None on shape/sampling mismatch. + if tensor_ref["shape"] != tensor_test["shape"]: + return None + if tensor_ref["step"] != tensor_test["step"]: + return None + sub_config = self._get_sub_config(step_name, tensor_name) + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale + scale_unreg = (samples_ref**2).mean() ** 0.5 + rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 + rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 + max_diff = (samples_ref - samples_test).abs().max() + return { + "rms_abs": rms.item(), + "rms_rel": (rms / rms_scale).item(), + "max_abs": max_diff.item(), + "max_rel": (max_diff / rms_scale).item(), + "ref_scale": scale_unreg.item(), + "ref_scale_regularized": rms_scale.item(), + } + def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name): sub_config = self._get_sub_config(step_name, tensor_name) if tensor_ref["shape"] != tensor_test["shape"]: @@ -108,34 +132,33 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam ) return - samples_ref = tensor_ref["samples"].flatten().float() - samples_test = tensor_test["samples"].flatten().float() - if sub_config.scale != 1.0: - samples_test = samples_test / sub_config.scale - scale_unreg = (samples_ref**2).mean() ** 0.5 - rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 - rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 - max_diff = (samples_ref - samples_test).abs().max() + metrics = self._compute_diff(tensor_ref, tensor_test, step_name, tensor_name) + rms_scale = metrics["ref_scale_regularized"] + scale_unreg = metrics["ref_scale"] tensor_errors = [] - if rms > sub_config.rms_abs_tolerance: - tensor_errors.append(f" * RMS diff absolute = {rms} > {sub_config.rms_abs_tolerance}") + if metrics["rms_abs"] > sub_config.rms_abs_tolerance: + tensor_errors.append(f" * RMS diff absolute = {metrics['rms_abs']} > {sub_config.rms_abs_tolerance}") - if rms / rms_scale > sub_config.rms_rel_tolerance: + if metrics["rms_rel"] > sub_config.rms_rel_tolerance: tensor_errors.append( - f" * RMS diff scaled = {rms / rms_scale} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * RMS diff scaled = {metrics['rms_rel']} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) - if max_diff > sub_config.max_abs_tolerance: - tensor_errors.append(f" * Max diff absolute = {max_diff} > {sub_config.max_abs_tolerance}") + if metrics["max_abs"] > sub_config.max_abs_tolerance: + tensor_errors.append(f" * Max diff absolute = {metrics['max_abs']} > {sub_config.max_abs_tolerance}") - if max_diff / rms_scale > sub_config.max_rel_tolerance: + if metrics["max_rel"] > sub_config.max_rel_tolerance: tensor_errors.append( - f" * Max diff scaled = {max_diff / rms_scale} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * Max diff scaled = {metrics['max_rel']} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) if tensor_errors: + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale tensor_errors.extend( [ f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()), diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 0b4dbafc1..f3febae4b 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -18,9 +18,9 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.utils import Assert, header -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 03ebac757..3c95d0dea 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -18,9 +18,9 @@ from fast_llm.data.dataset.sampled import logger from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preparation.tokenizer import TokenizerConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index f3bbbac8d..d08b023b9 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -4,7 +4,7 @@ import torch -from tests.utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig logger = logging.getLogger(__name__) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py new file mode 100644 index 000000000..782ff996d --- /dev/null +++ b/tools/evaluate_precision.py @@ -0,0 +1,201 @@ +import json +import logging +import pathlib +import typing + +import yaml + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.training.config import TrainerConfig + +# Populate the trainer dynamic-type registry. +import fast_llm.data.auto # noqa: F401 # isort:skip +import fast_llm.engine.checkpoint.convert # noqa: F401 # isort:skip +import fast_llm.models.auto # noqa: F401 # isort:skip + +logger = logging.getLogger(__name__) + + +# Tensor-log verbosity level. 13 gives 2**(13-3)=1024 sampled values per tensor, +# matching the convention in the existing layer-comparison tests. +_LOG_LEVEL = 13 +_REFERENCE_NAME = "reference" + + +@config_class() +class EvaluatePrecisionConfig(RunnableConfig): + training_config: pathlib.Path = Field( + desc="Path to a Fast-LLM training YAML serving as the fp32 reference configuration.", + hint=FieldHint.core, + ) + model_type: str = Field( + desc="Trainer dynamic-type name (e.g. 'gpt') used to dispatch to the right TrainerConfig subclass.", + hint=FieldHint.core, + ) + variants: dict[str, typing.Any] = Field( + desc="Named override bundles to evaluate against the fp32 reference." + " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", + hint=FieldHint.core, + ) + output_dir: pathlib.Path = Field( + desc="Directory for per-run tensor-log artifacts and the final JSON report.", + hint=FieldHint.core, + ) + num_samples: int = Field( + default=1024, + desc="Number of sampled values stored per logged tensor.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + super()._validate() + assert self.training_config.is_file(), f"Training config not found: {self.training_config}" + assert _REFERENCE_NAME not in self.variants, f"'{_REFERENCE_NAME}' is reserved for the fp32 baseline." + for name, overrides in self.variants.items(): + assert isinstance(overrides, dict) and all( + isinstance(k, str) for k in overrides + ), f"Variant {name!r} must be a flat dict of dotted-path string keys." + + def run(self) -> None: + base_dict = yaml.safe_load(self.training_config.read_text()) + for field_name in ("compute_dtype", "optimization_dtype"): + current = _get_nested(base_dict, ("model", "distributed", field_name)) + if current is not None and DataType(current) is not DataType.float32: + logger.warning( + f"Base config sets model.distributed.{field_name}={current!r};" + f" overriding to float32 for the reference run." + ) + + runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} + runs.update(self.variants) + for name, variant_overrides in runs.items(): + self._run_one(name, variant_overrides) + + ref_artifacts = self._artifact_path(_REFERENCE_NAME) + results = {name: self._compare(ref_artifacts, self._artifact_path(name)) for name in self.variants} + + report_path = self.output_dir / "precision_report.json" + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text(json.dumps(results, indent=2)) + logger.info(f"Wrote report to {report_path}") + + for name, rows in results.items(): + _print_table(name, rows) + + def _artifact_path(self, name: str) -> pathlib.Path: + return self.output_dir / name / "runs" / "0" / "artifacts" + + def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: + experiment_dir = (self.output_dir / name).resolve() + forced_fp32 = { + "model.distributed.compute_dtype": "float32", + "model.distributed.optimization_dtype": "float32", + } + tool_overrides = { + "training.train_iters": 1, + "training.checkpoint.interval": None, + "run.tensor_logs.save": True, + "run.tensor_logs.show": False, + "run.tensor_logs.max_elements": self.num_samples, + "run.experiment_dir": str(experiment_dir), + "model.multi_stage.debug_layer_outputs": _LOG_LEVEL, + "model.multi_stage.debug_layer_gradients": _LOG_LEVEL, + } + # Compose: forced fp32 first so a variant can override it (e.g. compute_dtype=bfloat16); + # tool overrides last so logging and single-iteration mode always win. + combined = {**forced_fp32, **variant_overrides, **tool_overrides} + cli_overrides = [f"{key}={yaml.safe_dump(value).strip()}" for key, value in combined.items()] + logger.info(f"=== Running {name!r} ===") + if variant_overrides: + logger.info(f"Variant overrides: {variant_overrides}") + TrainerConfig.parse_and_run([self.model_type, "-c", str(self.training_config), *cli_overrides]) + + def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: + compare_config = CompareConfig() + errors: list[str] = [] + ref_logs = compare_config._extract_tensor_logs(ref_path, errors) + test_logs = compare_config._extract_tensor_logs(test_path, errors) + for error in errors: + logger.warning(error) + rows: list[dict[str, typing.Any]] = [] + for step_name in sorted(ref_logs): + if step_name not in test_logs: + logger.warning(f"Step {step_name!r} missing from test logs") + continue + step_ref = ref_logs[step_name] + step_test = test_logs[step_name] + for tensor_name, ref in step_ref.items(): + if tensor_name not in step_test: + continue + metrics = compare_config._compute_diff(ref, step_test[tensor_name], step_name, tensor_name) + if metrics is None: + continue + rows.append( + { + "step": step_name, + "tensor_name": tensor_name, + "kind": _classify(tensor_name), + "shape": ref["shape"], + **metrics, + } + ) + return rows + + +def _classify(tensor_name: str) -> str: + # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" + # and " bw[, mb=…]"; log_distributed_tensor may prefix the name + # with "Global " and append a ": " suffix when reconstructing a + # tensor-parallel-global tensor. + for kind in ("fw", "bw"): + if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): + return kind + return "other" + + +def _get_nested(d: typing.Any, keys: tuple[str, ...]) -> typing.Any: + for k in keys: + if not isinstance(d, dict) or k not in d: + return None + d = d[k] + return d + + +def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: + print(f"\n=== Variant: {name} ===") + if not rows: + print("(no matching tensors)") + return + columns = [ + ("step", "step", 6), + ("kind", "kind", 6), + ("tensor_name", "tensor", 48), + ("shape", "shape", 22), + ("ref_scale", "ref_scale", 12), + ("rms_abs", "rms_abs", 12), + ("rms_rel", "rms_rel", 12), + ("max_abs", "max_abs", 12), + ("max_rel", "max_rel", 12), + ] + header = " ".join(f"{title:<{width}}" for _, title, width in columns) + print(header) + print("-" * len(header)) + for row in rows: + parts = [] + for key, _, width in columns: + value = row[key] + if isinstance(value, float): + cell = f"{value:.4e}" + elif isinstance(value, list): + cell = "x".join(str(x) for x in value) + else: + cell = str(value) + parts.append(f"{cell:<{width}}") + print(" ".join(parts)) + + +if __name__ == "__main__": + EvaluatePrecisionConfig.parse_and_run() From 4dd6c1498b4a39d733a0702c0ab381e64cfafd0a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 14:57:14 -0400 Subject: [PATCH 02/41] Collapse to a single config; require a checkpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tool now takes a single YAML containing `pretrained:` (the checkpoint that defines the model architecture + weights), `variants:`, `output_dir:` and a few optional knobs (`model_type`, `num_samples`, `micro_batch_size`, `sequence_length`). The training/optimizer/data sections of the underlying training config are hardcoded — they have no bearing on the propagation measurement (1 iteration, no checkpoint save, random tokens, dummy learning rate, optimization dtype forced to float32 alongside compute dtype). A variant can still override any of the hardcoded fields via the dotted-path mechanism if needed. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 115 ++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 44 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 782ff996d..016744468 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -7,7 +7,6 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.training.config import TrainerConfig @@ -27,12 +26,9 @@ @config_class() class EvaluatePrecisionConfig(RunnableConfig): - training_config: pathlib.Path = Field( - desc="Path to a Fast-LLM training YAML serving as the fp32 reference configuration.", - hint=FieldHint.core, - ) - model_type: str = Field( - desc="Trainer dynamic-type name (e.g. 'gpt') used to dispatch to the right TrainerConfig subclass.", + pretrained: dict[str, typing.Any] = Field( + desc="Fast-LLM `CheckpointLoadConfig` dict (e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`)." + " The model architecture and weights are loaded from this checkpoint.", hint=FieldHint.core, ) variants: dict[str, typing.Any] = Field( @@ -44,15 +40,29 @@ class EvaluatePrecisionConfig(RunnableConfig): desc="Directory for per-run tensor-log artifacts and the final JSON report.", hint=FieldHint.core, ) + model_type: str = Field( + default="gpt", + desc="Trainer dynamic-type name used to dispatch to the right `TrainerConfig` subclass.", + hint=FieldHint.optional, + ) num_samples: int = Field( default=1024, desc="Number of sampled values stored per logged tensor.", hint=FieldHint.feature, ) + micro_batch_size: int = Field( + default=1, + desc="Micro-batch size for the single forward+backward pass.", + hint=FieldHint.feature, + ) + sequence_length: int = Field( + default=2048, + desc="Sequence length (maximum document length) for the random input.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() - assert self.training_config.is_file(), f"Training config not found: {self.training_config}" assert _REFERENCE_NAME not in self.variants, f"'{_REFERENCE_NAME}' is reserved for the fp32 baseline." for name, overrides in self.variants.items(): assert isinstance(overrides, dict) and all( @@ -60,15 +70,7 @@ def _validate(self) -> None: ), f"Variant {name!r} must be a flat dict of dotted-path string keys." def run(self) -> None: - base_dict = yaml.safe_load(self.training_config.read_text()) - for field_name in ("compute_dtype", "optimization_dtype"): - current = _get_nested(base_dict, ("model", "distributed", field_name)) - if current is not None and DataType(current) is not DataType.float32: - logger.warning( - f"Base config sets model.distributed.{field_name}={current!r};" - f" overriding to float32 for the reference run." - ) - + self.output_dir.mkdir(parents=True, exist_ok=True) runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) for name, variant_overrides in runs.items(): @@ -78,7 +80,6 @@ def run(self) -> None: results = {name: self._compare(ref_artifacts, self._artifact_path(name)) for name in self.variants} report_path = self.output_dir / "precision_report.json" - report_path.parent.mkdir(parents=True, exist_ok=True) report_path.write_text(json.dumps(results, indent=2)) logger.info(f"Wrote report to {report_path}") @@ -89,29 +90,57 @@ def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: - experiment_dir = (self.output_dir / name).resolve() - forced_fp32 = { - "model.distributed.compute_dtype": "float32", - "model.distributed.optimization_dtype": "float32", - } - tool_overrides = { - "training.train_iters": 1, - "training.checkpoint.interval": None, - "run.tensor_logs.save": True, - "run.tensor_logs.show": False, - "run.tensor_logs.max_elements": self.num_samples, - "run.experiment_dir": str(experiment_dir), - "model.multi_stage.debug_layer_outputs": _LOG_LEVEL, - "model.multi_stage.debug_layer_gradients": _LOG_LEVEL, - } - # Compose: forced fp32 first so a variant can override it (e.g. compute_dtype=bfloat16); - # tool overrides last so logging and single-iteration mode always win. - combined = {**forced_fp32, **variant_overrides, **tool_overrides} - cli_overrides = [f"{key}={yaml.safe_dump(value).strip()}" for key, value in combined.items()] + config_dict = self._build_config_dict(name) + # Apply variant overrides on top of the forced-fp32 baseline so a variant can set + # `model.distributed.compute_dtype: bfloat16` (etc.) and have it win. + for dotted_key, value in variant_overrides.items(): + _set_nested(config_dict, dotted_key.split("."), value) + config_yaml = self.output_dir / f"{name}_config.yaml" + config_yaml.write_text(yaml.safe_dump(config_dict)) logger.info(f"=== Running {name!r} ===") if variant_overrides: logger.info(f"Variant overrides: {variant_overrides}") - TrainerConfig.parse_and_run([self.model_type, "-c", str(self.training_config), *cli_overrides]) + TrainerConfig.parse_and_run([self.model_type, "-c", str(config_yaml)]) + + def _build_config_dict(self, name: str) -> dict[str, typing.Any]: + return { + "pretrained": self.pretrained, + "training": { + "train_iters": 1, + "num_workers": 0, + "logs": {"interval": 1}, + }, + "optimizer": { + "learning_rate": { + "base": 0.0, + "decay_style": "constant", + "warmup_iterations": 0, + }, + }, + "data": { + "datasets": {"training": {"type": "random"}}, + "micro_batch_size": self.micro_batch_size, + "maximum_document_length": self.sequence_length, + }, + "run": { + "experiment_dir": str((self.output_dir / name).resolve()), + "tensor_logs": { + "save": True, + "show": False, + "max_elements": self.num_samples, + }, + }, + "model": { + "distributed": { + "compute_dtype": "float32", + "optimization_dtype": "float32", + }, + "multi_stage": { + "debug_layer_outputs": _LOG_LEVEL, + "debug_layer_gradients": _LOG_LEVEL, + }, + }, + } def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: compare_config = CompareConfig() @@ -156,12 +185,10 @@ def _classify(tensor_name: str) -> str: return "other" -def _get_nested(d: typing.Any, keys: tuple[str, ...]) -> typing.Any: - for k in keys: - if not isinstance(d, dict) or k not in d: - return None - d = d[k] - return d +def _set_nested(d: dict[str, typing.Any], keys: list[str], value: typing.Any) -> None: + for key in keys[:-1]: + d = d.setdefault(key, {}) + d[keys[-1]] = value def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: From 5ebea3374483330ccb8507b1fbebd5515a91c7c3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 15:07:47 -0400 Subject: [PATCH 03/41] Expose `model:` alongside `pretrained:` in the tool config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tool's input mirrors the trainer config's top-level shape: both `model:` (FastLLMModelConfig dict) and `pretrained:` are user-facing, and either or both may be set. Pretrained-from-HF is one config choice among many — a user can also specify the architecture inline, or load from HF and override individual fields. The forced fp32 dtypes and tool-required debug levels are now applied as overrides on top of whatever the user supplies, instead of being hardcoded into the model section. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 016744468..09925f263 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -26,11 +26,19 @@ @config_class() class EvaluatePrecisionConfig(RunnableConfig): - pretrained: dict[str, typing.Any] = Field( - desc="Fast-LLM `CheckpointLoadConfig` dict (e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`)." - " The model architecture and weights are loaded from this checkpoint.", + model: dict[str, typing.Any] = Field( + default_factory=dict, + desc="`FastLLMModelConfig` dict (`base_model`, `distributed`, `multi_stage`)." + " Forwarded into the trainer config as-is alongside `pretrained`. Either or both" + " can be set: `pretrained` to load architecture/weights from a checkpoint," + " `model` to specify the architecture inline or override pretrained fields.", hint=FieldHint.core, ) + pretrained: dict[str, typing.Any] = Field( + default_factory=dict, + desc="`CheckpointLoadConfig` dict, e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`.", + hint=FieldHint.optional, + ) variants: dict[str, typing.Any] = Field( desc="Named override bundles to evaluate against the fp32 reference." " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", @@ -91,10 +99,14 @@ def _artifact_path(self, name: str) -> pathlib.Path: def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: config_dict = self._build_config_dict(name) - # Apply variant overrides on top of the forced-fp32 baseline so a variant can set - # `model.distributed.compute_dtype: bfloat16` (etc.) and have it win. + # Force fp32 on the reference baseline (variants apply on top and can re-override). + _set_nested(config_dict, ["model", "distributed", "compute_dtype"], "float32") + _set_nested(config_dict, ["model", "distributed", "optimization_dtype"], "float32") for dotted_key, value in variant_overrides.items(): _set_nested(config_dict, dotted_key.split("."), value) + # Tool-required overrides always win — variants must not silently disable tensor logging. + _set_nested(config_dict, ["model", "multi_stage", "debug_layer_outputs"], _LOG_LEVEL) + _set_nested(config_dict, ["model", "multi_stage", "debug_layer_gradients"], _LOG_LEVEL) config_yaml = self.output_dir / f"{name}_config.yaml" config_yaml.write_text(yaml.safe_dump(config_dict)) logger.info(f"=== Running {name!r} ===") @@ -105,6 +117,7 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: def _build_config_dict(self, name: str) -> dict[str, typing.Any]: return { "pretrained": self.pretrained, + "model": self.model, "training": { "train_iters": 1, "num_workers": 0, @@ -130,16 +143,6 @@ def _build_config_dict(self, name: str) -> dict[str, typing.Any]: "max_elements": self.num_samples, }, }, - "model": { - "distributed": { - "compute_dtype": "float32", - "optimization_dtype": "float32", - }, - "multi_stage": { - "debug_layer_outputs": _LOG_LEVEL, - "debug_layer_gradients": _LOG_LEVEL, - }, - }, } def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: From 4c444d81f52955dad0c07c2b91adbfc7e38ac6aa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 15:15:59 -0400 Subject: [PATCH 04/41] Inherit PretrainedGPTModelConfig; use Config update mechanism MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tool now inherits from `PretrainedGPTModelConfig` so `model` and `pretrained` are typed `FastLLMModelConfig` / `CheckpointLoadConfig` fields rather than loose dicts — validated, autocompleted, and introspectable like any other Fast-LLM config block. Per-variant trainer configs are built with `TrainerConfig.get_subclass(...) .from_dict(base, *updates)` instead of mutating a dict and round-tripping through YAML. Updates use tuple-keyed dotted paths so forced-fp32, variant overrides, and tool-required debug-logging overrides compose cleanly in the right precedence. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 93 ++++++++++++++----------------------- 1 file changed, 36 insertions(+), 57 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 09925f263..4c56848b6 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -3,12 +3,11 @@ import pathlib import typing -import yaml - from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.models.gpt.config import PretrainedGPTModelConfig # Populate the trainer dynamic-type registry. import fast_llm.data.auto # noqa: F401 # isort:skip @@ -22,23 +21,20 @@ # matching the convention in the existing layer-comparison tests. _LOG_LEVEL = 13 _REFERENCE_NAME = "reference" +_MODEL_TYPE = "gpt" @config_class() -class EvaluatePrecisionConfig(RunnableConfig): - model: dict[str, typing.Any] = Field( - default_factory=dict, - desc="`FastLLMModelConfig` dict (`base_model`, `distributed`, `multi_stage`)." - " Forwarded into the trainer config as-is alongside `pretrained`. Either or both" - " can be set: `pretrained` to load architecture/weights from a checkpoint," - " `model` to specify the architecture inline or override pretrained fields.", - hint=FieldHint.core, - ) - pretrained: dict[str, typing.Any] = Field( - default_factory=dict, - desc="`CheckpointLoadConfig` dict, e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`.", - hint=FieldHint.optional, - ) +class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): + """Evaluate layer-wise numerical-error propagation against an fp32 reference. + + Inherits `model` and `pretrained` from `PretrainedGPTModelConfig`: either or both + can be set in the YAML. The tool runs one fp32 reference + one trainer invocation + per variant, captures per-layer forward activations and input gradients via the + standard tensor-logs pipeline, and reports per-tensor RMS / max diffs. + """ + + _abstract = False variants: dict[str, typing.Any] = Field( desc="Named override bundles to evaluate against the fp32 reference." " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", @@ -48,11 +44,6 @@ class EvaluatePrecisionConfig(RunnableConfig): desc="Directory for per-run tensor-log artifacts and the final JSON report.", hint=FieldHint.core, ) - model_type: str = Field( - default="gpt", - desc="Trainer dynamic-type name used to dispatch to the right `TrainerConfig` subclass.", - hint=FieldHint.optional, - ) num_samples: int = Field( default=1024, desc="Number of sampled values stored per logged tensor.", @@ -98,37 +89,18 @@ def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: - config_dict = self._build_config_dict(name) - # Force fp32 on the reference baseline (variants apply on top and can re-override). - _set_nested(config_dict, ["model", "distributed", "compute_dtype"], "float32") - _set_nested(config_dict, ["model", "distributed", "optimization_dtype"], "float32") - for dotted_key, value in variant_overrides.items(): - _set_nested(config_dict, dotted_key.split("."), value) - # Tool-required overrides always win — variants must not silently disable tensor logging. - _set_nested(config_dict, ["model", "multi_stage", "debug_layer_outputs"], _LOG_LEVEL) - _set_nested(config_dict, ["model", "multi_stage", "debug_layer_gradients"], _LOG_LEVEL) - config_yaml = self.output_dir / f"{name}_config.yaml" - config_yaml.write_text(yaml.safe_dump(config_dict)) - logger.info(f"=== Running {name!r} ===") - if variant_overrides: - logger.info(f"Variant overrides: {variant_overrides}") - TrainerConfig.parse_and_run([self.model_type, "-c", str(config_yaml)]) - - def _build_config_dict(self, name: str) -> dict[str, typing.Any]: - return { - "pretrained": self.pretrained, - "model": self.model, + # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. + # Forced fp32 on the reference baseline lives in here too so a variant can override it. + base_dict: dict[str, typing.Any] = { + "pretrained": self.pretrained.to_dict(), + "model": self.model.to_dict(), "training": { "train_iters": 1, "num_workers": 0, "logs": {"interval": 1}, }, "optimizer": { - "learning_rate": { - "base": 0.0, - "decay_style": "constant", - "warmup_iterations": 0, - }, + "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, }, "data": { "datasets": {"training": {"type": "random"}}, @@ -137,13 +109,26 @@ def _build_config_dict(self, name: str) -> dict[str, typing.Any]: }, "run": { "experiment_dir": str((self.output_dir / name).resolve()), - "tensor_logs": { - "save": True, - "show": False, - "max_elements": self.num_samples, - }, + "tensor_logs": {"save": True, "show": False, "max_elements": self.num_samples}, }, } + fp32_dtypes = { + ("model", "distributed", "compute_dtype"): "float32", + ("model", "distributed", "optimization_dtype"): "float32", + } + variant_updates = {tuple(key.split(".")): value for key, value in variant_overrides.items()} + # Tool-required overrides win over variants — a variant must not silently disable tensor logging. + tool_overrides = { + ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, + ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, + } + logger.info(f"=== Running {name!r} ===") + if variant_overrides: + logger.info(f"Variant overrides: {variant_overrides}") + trainer_class = TrainerConfig.get_subclass(_MODEL_TYPE) + trainer_config = trainer_class.from_dict(base_dict, fp32_dtypes, variant_updates, tool_overrides) + trainer_config.configure_logging() + trainer_config._get_runnable()() def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: compare_config = CompareConfig() @@ -188,12 +173,6 @@ def _classify(tensor_name: str) -> str: return "other" -def _set_nested(d: dict[str, typing.Any], keys: list[str], value: typing.Any) -> None: - for key in keys[:-1]: - d = d.setdefault(key, {}) - d[keys[-1]] = value - - def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: print(f"\n=== Variant: {name} ===") if not rows: From 35206a6c2a37e2ee0d669bcabbed6b5c0cd885cc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 15:46:47 -0400 Subject: [PATCH 05/41] Expand HF metadata allowlist for newer transformers configs `transformers.PretrainedConfig.to_dict()` serializes a growing set of generic defaults (generation knobs, family markers, encoder-decoder flags). The Fast-LLM allowlist covered only a subset, so loading any modern HF Llama checkpoint via `pretrained.format: llama` tripped the coverage walker on keys like `torchscript`, `is_decoder`, `is_llama_config`, `rope_interleaved`, and the full set of generation defaults. Fill in the missing entries, grouped by category. None of them are architecture knobs that Fast-LLM consumes. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/huggingface.py | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index c055a7f2c..a4810dc1a 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -128,20 +128,32 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: { # transformers PretrainedConfig "_name_or_path", + "add_cross_attention", "architectures", "auto_map", "chunk_size_feed_forward", + "cross_attention_hidden_size", "dtype", + "finetuning_task", "id2label", + "is_decoder", "is_encoder_decoder", "label2id", "model_type", "output_attentions", "output_hidden_states", + "prefix", "problem_type", + "pruned_heads", "return_dict", + "task_specific_params", + "tf_legacy_loss", + "tie_encoder_decoder", + "tokenizer_class", "torch_dtype", + "torchscript", "transformers_version", + "use_bfloat16", "use_cache", # Token ids — generation/inference, not architecture. "bos_token_id", @@ -149,10 +161,39 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: "eos_token_id", "pad_token_id", "sep_token_id", + # Generation defaults — never architecture. + "bad_words_ids", + "begin_suppress_tokens", + "diversity_penalty", + "do_sample", + "early_stopping", + "encoder_no_repeat_ngram_size", + "exponential_decay_length_penalty", + "forced_bos_token_id", + "forced_eos_token_id", + "length_penalty", + "max_length", + "min_length", + "no_repeat_ngram_size", + "num_beam_groups", + "num_beams", + "num_return_sequences", + "output_scores", + "remove_invalid_values", + "repetition_penalty", + "return_dict_in_generate", + "suppress_tokens", + "temperature", + "top_k", + "top_p", + "typical_p", # Initialization / pretraining metadata Fast-LLM does not consume. "initializer_range", "max_position_embeddings", "pretraining_tp", + # Family markers / default-valued knobs serialized by recent transformers versions. + "is_llama_config", + "rope_interleaved", } ) From bde1efa903288e06bb4d67386a11811aeabf033c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:14:51 -0400 Subject: [PATCH 06/41] Reshape console table for readability Drop step / shape / max_rel columns, shorten the tensor name to the description after the colon, reorder to Tensor / Kind / Relative / Absolute / Max / Scale, format Relative as percent and the rest with `.3g`. The JSON report keeps every field. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 4c56848b6..0e7f70707 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -178,32 +178,19 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: if not rows: print("(no matching tensors)") return - columns = [ - ("step", "step", 6), - ("kind", "kind", 6), - ("tensor_name", "tensor", 48), - ("shape", "shape", 22), - ("ref_scale", "ref_scale", 12), - ("rms_abs", "rms_abs", 12), - ("rms_rel", "rms_rel", 12), - ("max_abs", "max_abs", 12), - ("max_rel", "max_rel", 12), + columns: list[tuple[str, str, int, typing.Callable[[typing.Any], str]]] = [ + ("tensor_name", "Tensor", 28, lambda v: v.split(":", 1)[-1].strip()), + ("kind", "Kind", 4, str), + ("rms_rel", "Relative", 9, lambda v: f"{v * 100:.3g}%"), + ("rms_abs", "Absolute", 10, lambda v: f"{v:.3g}"), + ("max_abs", "Max", 10, lambda v: f"{v:.3g}"), + ("ref_scale", "Scale", 10, lambda v: f"{v:.3g}"), ] - header = " ".join(f"{title:<{width}}" for _, title, width in columns) + header = " ".join(f"{title:<{width}}" for _, title, width, _ in columns) print(header) print("-" * len(header)) for row in rows: - parts = [] - for key, _, width in columns: - value = row[key] - if isinstance(value, float): - cell = f"{value:.4e}" - elif isinstance(value, list): - cell = "x".join(str(x) for x in value) - else: - cell = str(value) - parts.append(f"{cell:<{width}}") - print(" ".join(parts)) + print(" ".join(f"{format_fn(row[key]):<{width}}" for key, _, width, format_fn in columns)) if __name__ == "__main__": From 8099b51cae606914c3a6c3c08b443165b061519d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:22:48 -0400 Subject: [PATCH 07/41] Merge tensor+kind, fix decimal precision in console table Drop the separate Kind column and append `(fw)` / `(bw)` to the shortened tensor name. Switch numeric formatting to fixed precision: Relative shows `.2f` percent, Absolute / Max / Scale show `.2e` scientific. Every column now lines up on a consistent digit count. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 0e7f70707..f74a31649 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -178,19 +178,18 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: if not rows: print("(no matching tensors)") return - columns: list[tuple[str, str, int, typing.Callable[[typing.Any], str]]] = [ - ("tensor_name", "Tensor", 28, lambda v: v.split(":", 1)[-1].strip()), - ("kind", "Kind", 4, str), - ("rms_rel", "Relative", 9, lambda v: f"{v * 100:.3g}%"), - ("rms_abs", "Absolute", 10, lambda v: f"{v:.3g}"), - ("max_abs", "Max", 10, lambda v: f"{v:.3g}"), - ("ref_scale", "Scale", 10, lambda v: f"{v:.3g}"), + columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ + ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), + ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.2e}"), + ("Max", 10, lambda r: f"{r['max_abs']:.2e}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.2e}"), ] - header = " ".join(f"{title:<{width}}" for _, title, width, _ in columns) + header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) print("-" * len(header)) for row in rows: - print(" ".join(f"{format_fn(row[key]):<{width}}" for key, _, width, format_fn in columns)) + print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) if __name__ == "__main__": From dbd7702f7db29e85a9a31021475c48f1bce508e8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:26:36 -0400 Subject: [PATCH 08/41] Switch back to fixed-decimal formatting in the table Scientific notation was overkill for values that mostly land between 0.01 and a few hundred. `.3f` is more readable while keeping the per-column digit count consistent. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index f74a31649..6cd30a224 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -181,9 +181,9 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), - ("Absolute", 10, lambda r: f"{r['rms_abs']:.2e}"), - ("Max", 10, lambda r: f"{r['max_abs']:.2e}"), - ("Scale", 10, lambda r: f"{r['ref_scale']:.2e}"), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.3f}"), + ("Max", 10, lambda r: f"{r['max_abs']:.3f}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.3f}"), ] header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) From 152ffc36df08da66f301052286f9e93ed3d7f056 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:43:19 -0400 Subject: [PATCH 09/41] Wipe per-variant experiment dir before each run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fast-LLM's `Run.__init__` picks the next free `runs/` subdirectory based on what already exists, but `_artifact_path` reads `runs/0` unconditionally. Without this wipe, re-running the tool against the same `output_dir` reads stale artifacts from the first invocation and silently reports old numbers — even though the trainer correctly ran with the new config. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 6cd30a224..a9a17b2e4 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -1,6 +1,7 @@ import json import logging import pathlib +import shutil import typing from fast_llm.config import Field, FieldHint, config_class @@ -89,6 +90,12 @@ def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: + # The trainer's Run picks the next `runs/` subdir based on what already exists; wipe + # any prior contents so each invocation lands in `runs/0` and stale artifacts can't be + # read by `_artifact_path` below. + experiment_dir = self.output_dir / name + if experiment_dir.exists(): + shutil.rmtree(experiment_dir) # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. # Forced fp32 on the reference baseline lives in here too so a variant can override it. base_dict: dict[str, typing.Any] = { From 7e98500d85c3f20d44f410d7d8cd07ca07a4abed Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:04:45 -0400 Subject: [PATCH 10/41] Support pre-generated memmap dataset; misc table-format polish MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a `data_path` field to the tool. When set, the tool lazily generates a tokenized memmap dataset with random advantages and old_logprobs at the given path (via the test helper `tests/utils/dataset._get_test_dataset`) and uses it as the training input. Required for policy-gradient losses like GSPO/GRPO that consume those fields. Without it, the tool falls back to the random token generator as before. Console table now formats numeric columns with `.4g` so 1e-7-scale GSPO gradients aren't rounded to zero while normal CE-magnitude values still read as fixed-point numbers. Rename `download_santacoder_tokenizer` to `download_test_tokenizer` — it actually downloads the GPT-2 tokenizer. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/data/test_tokenizer.py | 4 ++-- tests/models/test_lm_eval.py | 4 ++-- tests/utils/dataset.py | 4 ++-- tools/evaluate_precision.py | 39 ++++++++++++++++++++++++++++++++---- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 184294551..04a24e2ae 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -2,13 +2,13 @@ from fast_llm.data.preparation.tokenizer import Tokenizer, TokenizerConfig from fast_llm.utils import Assert -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.global_variables import TOKENIZER_PATH @pytest.fixture(scope="session") def common_tokenizer() -> Tokenizer: - download_santacoder_tokenizer() + download_test_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 7ae26c2d6..c8b5fd004 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,7 +3,7 @@ import pytest -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup @@ -15,7 +15,7 @@ @pytest.fixture(scope="module") def tokenizer_path(): - download_santacoder_tokenizer() + download_test_tokenizer() return TOKENIZER_PATH diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a2ea2f46e..e7b206cf5 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -14,7 +14,7 @@ from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH -def download_santacoder_tokenizer(): +def download_test_tokenizer(): if not TOKENIZER_FILE.is_file(): import transformers @@ -218,7 +218,7 @@ def _get_test_dataset( if has_grpo_data: source_schema["advantages"] = "advantages" - download_santacoder_tokenizer() + download_test_tokenizer() preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( { "dataset": { diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index a9a17b2e4..55d01cc1f 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -60,6 +60,13 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): desc="Sequence length (maximum document length) for the random input.", hint=FieldHint.feature, ) + data_path: pathlib.Path | None = Field( + default=None, + desc="If set, prepare a tokenized memmap dataset with advantages and `old_log_probabilities`" + " at this path (using the test helper `_get_test_dataset`) and use it as the training" + " input — required for policy-gradient losses like GSPO/GRPO. If unset, uses random tokens.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -71,6 +78,7 @@ def _validate(self) -> None: def run(self) -> None: self.output_dir.mkdir(parents=True, exist_ok=True) + self._prepare_data() runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) for name, variant_overrides in runs.items(): @@ -86,6 +94,23 @@ def run(self) -> None: for name, rows in results.items(): _print_table(name, rows) + def _prepare_data(self) -> None: + if self.data_path is None: + return + if (self.data_path / "fast_llm_config.yaml").is_file(): + return + # Couples `tools/` to `tests/utils/` for now — extract later if it sticks. + from tests.utils.dataset import _get_test_dataset + + self.data_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Preparing memmap dataset at {self.data_path}") + _get_test_dataset( + self.data_path, + seed=42, + has_grpo_data=True, + max_vocab_size=self.model.base_model.embeddings.vocab_size, + ) + def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" @@ -110,7 +135,13 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, }, "data": { - "datasets": {"training": {"type": "random"}}, + "datasets": { + "training": ( + {"type": "file", "path": str(self.data_path / "fast_llm_config.yaml")} + if self.data_path is not None + else {"type": "random"} + ) + }, "micro_batch_size": self.micro_batch_size, "maximum_document_length": self.sequence_length, }, @@ -188,9 +219,9 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), - ("Absolute", 10, lambda r: f"{r['rms_abs']:.3f}"), - ("Max", 10, lambda r: f"{r['max_abs']:.3f}"), - ("Scale", 10, lambda r: f"{r['ref_scale']:.3f}"), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), + ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), ] header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) From 173ae0de6e3f6d9c5adf2ea6fe7fde8c6d5ca4f3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:19:20 -0400 Subject: [PATCH 11/41] Print per-variant summary at the end of the run After the per-tensor tables, emit a short summary block per variant showing first/last/max/median for forward and backward separately. Aggregates over the intermediate layers per metric column (max and median are computed per-column, so each row is a per-metric envelope of the intermediate band rather than the metrics of any single layer). Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 55d01cc1f..beed05564 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -2,6 +2,7 @@ import logging import pathlib import shutil +import statistics import typing from fast_llm.config import Field, FieldHint, config_class @@ -93,6 +94,9 @@ def run(self) -> None: for name, rows in results.items(): _print_table(name, rows) + print("\n=== Summary ===") + for name, rows in results.items(): + _print_table(name, _summary_rows(rows)) def _prepare_data(self) -> None: if self.data_path is None: @@ -200,6 +204,26 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows +def _summary_rows(rows: list[dict[str, typing.Any]]) -> list[dict[str, typing.Any]]: + out: list[dict[str, typing.Any]] = [] + metric_keys = ("rms_rel", "rms_abs", "max_abs", "ref_scale") + for kind in ("fw", "bw"): + group = [r for r in rows if r["kind"] == kind] + if not group: + continue + first, last = group[0], group[-1] + intermediate = group[1:-1] + out.append({**first, "tensor_name": "first", "kind": kind}) + out.append({**last, "tensor_name": "last", "kind": kind}) + if intermediate: + for agg_name, agg in (("max", max), ("median", statistics.median)): + aggregated = {"tensor_name": agg_name, "kind": kind} + for key in metric_keys: + aggregated[key] = agg(r[key] for r in intermediate) + out.append(aggregated) + return out + + def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name From 005fd6222b07e6be5751cea8e6134940eecbb9d9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:21:02 -0400 Subject: [PATCH 12/41] =?UTF-8?q?Reshape=20end-of-run=20summary:=20variant?= =?UTF-8?q?s=20=C3=97=20aggregations,=20relative=20only?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single compact table with one row per variant and columns for fw/bw first/last/max/median Relative %. Max/median are over intermediate layers (excluding first/last) when there is at least one intermediate row. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 46 ++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index beed05564..3b40c7e17 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -94,9 +94,7 @@ def run(self) -> None: for name, rows in results.items(): _print_table(name, rows) - print("\n=== Summary ===") - for name, rows in results.items(): - _print_table(name, _summary_rows(rows)) + _print_summary(results) def _prepare_data(self) -> None: if self.data_path is None: @@ -204,24 +202,30 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows -def _summary_rows(rows: list[dict[str, typing.Any]]) -> list[dict[str, typing.Any]]: - out: list[dict[str, typing.Any]] = [] - metric_keys = ("rms_rel", "rms_abs", "max_abs", "ref_scale") - for kind in ("fw", "bw"): - group = [r for r in rows if r["kind"] == kind] - if not group: - continue - first, last = group[0], group[-1] - intermediate = group[1:-1] - out.append({**first, "tensor_name": "first", "kind": kind}) - out.append({**last, "tensor_name": "last", "kind": kind}) - if intermediate: - for agg_name, agg in (("max", max), ("median", statistics.median)): - aggregated = {"tensor_name": agg_name, "kind": kind} - for key in metric_keys: - aggregated[key] = agg(r[key] for r in intermediate) - out.append(aggregated) - return out +def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + columns = [(f"{kind} {agg}", kind, agg) for kind in ("fw", "bw") for agg in ("first", "last", "max", "median")] + name_width = max((len(name) for name in results), default=7) + 2 + cell_width = 10 + print("\n=== Summary (Relative %) ===") + header = f"{'Variant':<{name_width}}" + "".join(f"{h:<{cell_width}}" for h, _, _ in columns) + print(header) + print("-" * len(header)) + for name, rows in results.items(): + cells = [] + for _, kind, agg in columns: + group = [r["rms_rel"] for r in rows if r["kind"] == kind] + if not group: + cells.append("n/a") + continue + if agg == "first": + value = group[0] + elif agg == "last": + value = group[-1] + else: + intermediate = group[1:-1] or group + value = max(intermediate) if agg == "max" else statistics.median(intermediate) + cells.append(f"{value * 100:.2f}%") + print(f"{name:<{name_width}}" + "".join(f"{c:<{cell_width}}" for c in cells)) def _classify(tensor_name: str) -> str: From c59465889ca16e141c85035e270f33657b861920 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:27:08 -0400 Subject: [PATCH 13/41] Clarify intermediate aggregation in summary header Rename `max`/`median` columns to `mid max`/`mid med` and add a header note (`mid = excluding first/last`) so it's clear the aggregation excludes the boundary layers. Also fix a column-collision bug where labels at exactly the cell width touched without separator. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 3b40c7e17..4feb10792 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -203,11 +203,12 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - columns = [(f"{kind} {agg}", kind, agg) for kind in ("fw", "bw") for agg in ("first", "last", "max", "median")] + agg_labels = {"first": "first", "last": "last", "max": "mid max", "median": "mid med"} + columns = [(f"{kind} {agg_labels[agg]}", kind, agg) for kind in ("fw", "bw") for agg in agg_labels] name_width = max((len(name) for name in results), default=7) + 2 - cell_width = 10 - print("\n=== Summary (Relative %) ===") - header = f"{'Variant':<{name_width}}" + "".join(f"{h:<{cell_width}}" for h, _, _ in columns) + cell_width = max(len(label) for label, _, _ in columns) + 1 + print("\n=== Summary (Relative %; mid = excluding first/last) ===") + header = f"{'Variant':<{name_width}}" + " ".join(f"{h:<{cell_width}}" for h, _, _ in columns) print(header) print("-" * len(header)) for name, rows in results.items(): @@ -225,7 +226,7 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: intermediate = group[1:-1] or group value = max(intermediate) if agg == "max" else statistics.median(intermediate) cells.append(f"{value * 100:.2f}%") - print(f"{name:<{name_width}}" + "".join(f"{c:<{cell_width}}" for c in cells)) + print(f"{name:<{name_width}}" + " ".join(f"{c:<{cell_width}}" for c in cells)) def _classify(tensor_name: str) -> str: From 3159f73efeb99564d6c42df550c4ee8537e6df94 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:31:25 -0400 Subject: [PATCH 14/41] Split summary across fw/bw rows; one extra precision digit Each variant now occupies two rows in the summary (fw on the first, bw on the second), with the metric columns shared. Reads more naturally and keeps the table half as wide. Percent precision goes from .2f to .3f so single-digit-percent differences between variants are visible. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 4feb10792..7b4d1fd05 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -204,29 +204,30 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: agg_labels = {"first": "first", "last": "last", "max": "mid max", "median": "mid med"} - columns = [(f"{kind} {agg_labels[agg]}", kind, agg) for kind in ("fw", "bw") for agg in agg_labels] name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(label) for label, _, _ in columns) + 1 + cell_width = max(len(label) for label in agg_labels.values()) + 2 print("\n=== Summary (Relative %; mid = excluding first/last) ===") - header = f"{'Variant':<{name_width}}" + " ".join(f"{h:<{cell_width}}" for h, _, _ in columns) + header = f"{'Variant':<{name_width}}{'':<4}" + " ".join(f"{label:<{cell_width}}" for label in agg_labels.values()) print(header) print("-" * len(header)) for name, rows in results.items(): - cells = [] - for _, kind, agg in columns: + for index, kind in enumerate(("fw", "bw")): group = [r["rms_rel"] for r in rows if r["kind"] == kind] - if not group: - cells.append("n/a") - continue - if agg == "first": - value = group[0] - elif agg == "last": - value = group[-1] - else: - intermediate = group[1:-1] or group - value = max(intermediate) if agg == "max" else statistics.median(intermediate) - cells.append(f"{value * 100:.2f}%") - print(f"{name:<{name_width}}" + " ".join(f"{c:<{cell_width}}" for c in cells)) + cells = [] + for agg in agg_labels: + if not group: + cells.append("n/a") + continue + if agg == "first": + value = group[0] + elif agg == "last": + value = group[-1] + else: + intermediate = group[1:-1] or group + value = max(intermediate) if agg == "max" else statistics.median(intermediate) + cells.append(f"{value * 100:.3f}%") + name_cell = name if index == 0 else "" + print(f"{name_cell:<{name_width}}{kind:<4}" + " ".join(f"{c:<{cell_width}}" for c in cells)) def _classify(tensor_name: str) -> str: From 6ef153e154ed86cffb609307c7b3e32febf60bf9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:34:20 -0400 Subject: [PATCH 15/41] Two-row column header in summary; chronological column order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Top header line groups columns under `fw` / `bw`; the second line lists the per-pass aggregations. Aggregations are ordered chronologically along the pass — first → mid med → mid max → last — so reading left to right traces the propagation. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 42 ++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 7b4d1fd05..f5d3d8e4e 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -203,31 +203,43 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - agg_labels = {"first": "first", "last": "last", "max": "mid max", "median": "mid med"} + # Chronological column order: first → intermediate (median, max) → last. + aggs = ("first", "median", "max", "last") + agg_labels = {"first": "first", "median": "mid med", "max": "mid max", "last": "last"} + kinds = ("fw", "bw") name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(label) for label in agg_labels.values()) + 2 + cell_width = max(len(label) for label in agg_labels.values()) + 1 + group_sep = " " + group_width = len(" ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs)) print("\n=== Summary (Relative %; mid = excluding first/last) ===") - header = f"{'Variant':<{name_width}}{'':<4}" + " ".join(f"{label:<{cell_width}}" for label in agg_labels.values()) - print(header) - print("-" * len(header)) + top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_width}}" for kind in kinds) + bottom = f"{'Variant':<{name_width}}" + group_sep.join( + " ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs) for _ in kinds + ) + print(top) + print(bottom) + print("-" * len(bottom)) for name, rows in results.items(): - for index, kind in enumerate(("fw", "bw")): - group = [r["rms_rel"] for r in rows if r["kind"] == kind] + groups = [] + for kind in kinds: + values = [r["rms_rel"] for r in rows if r["kind"] == kind] + intermediate = values[1:-1] or values cells = [] - for agg in agg_labels: - if not group: + for agg in aggs: + if not values: cells.append("n/a") continue if agg == "first": - value = group[0] + value = values[0] elif agg == "last": - value = group[-1] + value = values[-1] + elif agg == "max": + value = max(intermediate) else: - intermediate = group[1:-1] or group - value = max(intermediate) if agg == "max" else statistics.median(intermediate) + value = statistics.median(intermediate) cells.append(f"{value * 100:.3f}%") - name_cell = name if index == 0 else "" - print(f"{name_cell:<{name_width}}{kind:<4}" + " ".join(f"{c:<{cell_width}}" for c in cells)) + groups.append(" ".join(f"{c:<{cell_width}}" for c in cells)) + print(f"{name:<{name_width}}" + group_sep.join(groups)) def _classify(tensor_name: str) -> str: From 7327932e47f781405e7e02862fe52662780ea74b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:40:20 -0400 Subject: [PATCH 16/41] Add fp32_lm_head flag for vLLM precision parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an `fp32_lm_head` field on `LanguageModelHeadConfig`. When `True`, the LM head linear's input and weight are upcast to FP32 before the matmul, matching vLLM's `bf16_last_layer_fp32` quantization. This lets the trainer compute log-probabilities at the same numerical precision as the actor's sampling, so the importance-sampling ratio starts near 1.0 instead of being artificially inflated by a trainer/actor precision mismatch. The detached FP32 weight has `requires_grad=False`, which makes `output_parallel_linear_backward` skip the weight-grad path. The FSDP gradient contract is restored by computing `grad_weight = grad.t() @ saved_input` explicitly and accumulating into the original BF16 param's `grad_buffer` via `accumulate_gradient`. Off by default — disabled path is byte-identical to before. Cherry-picked from #526 to unblock the precision-evaluation tool's GSPO smoke test, which compares fp32_lm_head=true vs false. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/config.py | 7 +++++ fast_llm/layers/language_model/head.py | 34 +++++++++++++++++++----- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index bde33f297..6a0bfcfd6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. " + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 22c750082..eb67cd553 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + # detach → requires_grad=False → output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -285,12 +293,26 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - if grad is not None and self._config.final_logit_softcap is not None: + if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + if self._config.final_logit_softcap is not None: grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ From 76335dffd75b96c6baa7d008301bf87d2e7c5170 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:06:53 -0400 Subject: [PATCH 17/41] Extract layer-name labels for summary first/last columns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of generic `first` / `last` headers in the summary, use the actual layer name pulled from the matching tensor's `Global :` prefix. For the SmolLM2 smoke run that surfaces as `embeddings` / `head` on fw and `head` / `decoder.0` on bw — directly showing which layer the boundary values come from rather than making the reader guess. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 38 ++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index f5d3d8e4e..e27909333 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -202,19 +202,47 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows +def _layer_name(tensor_name: str) -> str: + # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; + # extract the layer to use as a meaningful column label. + prefix = tensor_name.split(":", 1)[0].strip().split() + if prefix and prefix[0] == "Global": + prefix = prefix[1:] + if prefix and prefix[-1] in ("fw", "bw"): + prefix = prefix[:-1] + return " ".join(prefix) if prefix else "?" + + def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: # Chronological column order: first → intermediate (median, max) → last. aggs = ("first", "median", "max", "last") - agg_labels = {"first": "first", "median": "mid med", "max": "mid max", "last": "last"} + # Per-pass labels for `first`/`last` come from the actual layer name on the matching row. + sample = next(iter(results.values())) + endpoint_labels: dict[tuple[str, str], str] = { + ("fw", "first"): "first", + ("fw", "last"): "last", + ("bw", "first"): "first", + ("bw", "last"): "last", + } + for kind in ("fw", "bw"): + group = [r for r in sample if r["kind"] == kind] + if group: + endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) + endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) + mid_labels = {"median": "mid med", "max": "mid max"} kinds = ("fw", "bw") + + def _label(kind: str, agg: str) -> str: + return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] + name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(label) for label in agg_labels.values()) + 1 + cell_width = max(len(_label(k, a)) for k in kinds for a in aggs) + 1 group_sep = " " - group_width = len(" ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs)) + group_widths = {kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs)) for kind in kinds} print("\n=== Summary (Relative %; mid = excluding first/last) ===") - top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_width}}" for kind in kinds) + top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) bottom = f"{'Variant':<{name_width}}" + group_sep.join( - " ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs) for _ in kinds + " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs) for kind in kinds ) print(top) print(bottom) From 8122946df08040dd3cc2ab2640b0cf18d5f1b619 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:24:29 -0400 Subject: [PATCH 18/41] Add `debug_hidden_states_log` to capture named tensors via output_hidden_states MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the only way to get a non-layer-output tensor (e.g. the LM head's logits) into `tensor_logs` was to crank `model_debug_level`, which logs every single `_debug`-emitted tensor (~700 per step for a 30-layer model). Add a `MultiStageConfig.debug_hidden_states_log: list[str]` field — regex patterns that get appended to each model input's `output_hidden_states` set. Matching tensors are still populated into `kwargs[hidden_states]` (existing contract for the HF inference wrapper); now they're also written to `tensor_logs` so the precision tool can compare them across variants. `_debug` already had the `output_hidden_state`-matched branch but only used it to populate `kwargs[hidden_states]`. Extending it to also call `log_distributed_tensor` at a fixed verbosity (13, matching the test convention so samples are recorded) is a small gating change. Plumbed through `GPTModel.get_preprocessing_config` → `LanguageModelBatchPreprocessingConfig.output_hidden_states` → `LanguageModelBatch.get_model_inputs`, which compiles the patterns and unions them into each `LanguageModelInput.output_hidden_states`. The precision tool now sets `[r"head\.logits"]` and surfaces logits as a dedicated `logits` column on the fw side of the summary table. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/data/document/config.py | 6 +++++ fast_llm/data/document/language_model.py | 7 +++++ fast_llm/engine/multi_stage/config.py | 8 ++++++ fast_llm/layers/block/block.py | 15 ++++++++--- fast_llm/models/gpt/model.py | 1 + tools/evaluate_precision.py | 34 +++++++++++++++++++----- 6 files changed, 61 insertions(+), 10 deletions(-) diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index a90bcdebc..fbfe60ac3 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -80,6 +80,12 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + output_hidden_states: list[str] = Field( + default_factory=list, + desc="Regex patterns to add to each model input's `output_hidden_states` set." + " Matching `_debug`-named tensors get populated into `kwargs[hidden_states]`" + " and (when running under a `Run` context) emitted into `tensor_logs`.", + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 16114cb80..000fcc01d 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -161,6 +161,13 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis self._set_target_inputs(model_inputs, config) + if config.output_hidden_states: + import re + + patterns = {re.compile(pattern) for pattern in config.output_hidden_states} + for model_input in model_inputs: + model_input.output_hidden_states.update(patterns) + return model_inputs def _set_target_inputs( diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 958a3d228..96cb52f09 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -139,6 +139,14 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. High overhead", hint=FieldHint.logging, ) + debug_hidden_states_log: list[str] = Field( + default_factory=list, + desc="Regex patterns for `_debug`-named tensors (`.`, e.g. `head.logits`," + " `decoder.0.norm_1`) to log to `tensor_logs`. Patterns are appended to each model" + " input's `output_hidden_states` set, so matching tensors are both populated into" + " `kwargs[hidden_states]` for downstream consumers and emitted into `tensor_logs`.", + hint=FieldHint.logging, + ) @config_class() diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 805eae1e5..0476a8107 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -18,6 +18,12 @@ logger = logging.getLogger(__name__) +# Verbosity used for `output_hidden_states`-driven tensor logging. `log_tensor` collects sampled +# tensor values only at level >= 3; 13 matches the convention in the layer-comparison tests +# (1024 sampled values per tensor). +_HIDDEN_STATE_LOG_LEVEL = 13 + + class DebugLayer: """ A debugging utility for blocks. @@ -55,11 +61,14 @@ def __call__( if level > 1: log_pipeline_parallel_main_rank(lambda: log_memory_usage(name, str)) - if level > 0 and tensor is not None: + # `output_hidden_state` requests full-fidelity capture even when `model_debug_level` is + # off — clamp the log level so samples are saved alongside summary stats. + log_level = max(level, _HIDDEN_STATE_LOG_LEVEL) if output_hidden_state else level + if log_level > 0 and tensor is not None: log_distributed_tensor( "", tensor, - level=level, + level=log_level, meta=meta, **logging_kwargs, ) @@ -67,7 +76,7 @@ def __call__( log_distributed_grad( "", tensor, - level=level, + level=log_level, meta=self._get_meta(tensor, f"{name}.grad", dims), **logging_kwargs, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2e9b4365b..f4d4b286a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -112,6 +112,7 @@ def get_preprocessing_config( return LanguageModelBatchPreprocessingConfig( phase=phase, micro_batch_splits=micro_batch_splits, + output_hidden_states=list(self._config.multi_stage.debug_hidden_states_log), **self._base_model.get_preprocessing_config(), ) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index e27909333..afcb818c0 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -161,6 +161,9 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: tool_overrides = { ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, + # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's + # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. + ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], } logger.info(f"=== Running {name!r} ===") if variant_overrides: @@ -213,9 +216,14 @@ def _layer_name(tensor_name: str) -> str: return " ".join(prefix) if prefix else "?" +def _logits_row(rows: list[dict[str, typing.Any]]) -> dict[str, typing.Any] | None: + return next( + (r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == "head.logits"), + None, + ) + + def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - # Chronological column order: first → intermediate (median, max) → last. - aggs = ("first", "median", "max", "last") # Per-pass labels for `first`/`last` come from the actual layer name on the matching row. sample = next(iter(results.values())) endpoint_labels: dict[tuple[str, str], str] = { @@ -229,31 +237,41 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: if group: endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) - mid_labels = {"median": "mid med", "max": "mid max"} + mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} + # Logits show up on the fw side via `output_hidden_states` ("Global : head.logits"); + # add a dedicated column for it (chronologically just before the head output / loss). + has_logits = _logits_row(sample) is not None + aggs_per_kind = { + "fw": ("first", "median", "max", "logits", "last") if has_logits else ("first", "median", "max", "last"), + "bw": ("first", "median", "max", "last"), + } kinds = ("fw", "bw") def _label(kind: str, agg: str) -> str: return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(_label(k, a)) for k in kinds for a in aggs) + 1 + cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + 1 group_sep = " " - group_widths = {kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs)) for kind in kinds} + group_widths = { + kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds + } print("\n=== Summary (Relative %; mid = excluding first/last) ===") top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) bottom = f"{'Variant':<{name_width}}" + group_sep.join( - " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs) for kind in kinds + " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds ) print(top) print(bottom) print("-" * len(bottom)) for name, rows in results.items(): + logits_value = _logits_row(rows)["rms_rel"] if _logits_row(rows) else float("nan") groups = [] for kind in kinds: values = [r["rms_rel"] for r in rows if r["kind"] == kind] intermediate = values[1:-1] or values cells = [] - for agg in aggs: + for agg in aggs_per_kind[kind]: if not values: cells.append("n/a") continue @@ -261,6 +279,8 @@ def _label(kind: str, agg: str) -> str: value = values[0] elif agg == "last": value = values[-1] + elif agg == "logits": + value = logits_value elif agg == "max": value = max(intermediate) else: From 4633bfde1d45ff78b8b3674d9678cac53d0dc0e4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:29:42 -0400 Subject: [PATCH 19/41] Capture logit gradients; expose them in the summary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The head's `logits` tensor has `requires_grad=False` (output of a custom-autograd Function), so the existing `_debug(logits, ...)` could only capture the forward value. Add a second `_debug(grad, "logits.grad", ...)` call right after the loss returns the explicit `dL/d_logits` so the gradient is captured at the same fidelity. With the precision tool's `output_hidden_states` pattern `r"head\.logits"`, both `head.logits` and `head.logits.grad` end up in tensor_logs. Tool summary surfaces both via dedicated `logits` columns — placed at end-of-fw and start-of-bw chronologically. For GSPO the bw-logits column reveals that the dL/dlogits computation itself is extremely precise (~0.001% relative error), and the apparent backward noise actually enters through the head matmul further downstream. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/head.py | 12 +++++++++ tools/evaluate_precision.py | 35 +++++++++++++++----------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index eb67cd553..8dd511480 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -293,6 +293,18 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) + if grad is not None: + # `logits` has `requires_grad=False` (custom-autograd), so the existing + # `_debug(logits, ...)` can't auto-capture the gradient. Log it explicitly here + # so `output_hidden_states` patterns covering `head.logits` also catch the grad. + self._debug( + grad, + f"logits.grad{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._vocab_dim), + kwargs, + scale=self._config.logits_scale_factor, + ) + if not self.training or grad is None: return sum(losses_) if losses_ else None, None diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index afcb818c0..61d612895 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -216,11 +216,8 @@ def _layer_name(tensor_name: str) -> str: return " ".join(prefix) if prefix else "?" -def _logits_row(rows: list[dict[str, typing.Any]]) -> dict[str, typing.Any] | None: - return next( - (r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == "head.logits"), - None, - ) +def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing.Any] | None: + return next((r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == name), None) def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: @@ -238,13 +235,14 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} - # Logits show up on the fw side via `output_hidden_states` ("Global : head.logits"); - # add a dedicated column for it (chronologically just before the head output / loss). - has_logits = _logits_row(sample) is not None - aggs_per_kind = { - "fw": ("first", "median", "max", "logits", "last") if has_logits else ("first", "median", "max", "last"), - "bw": ("first", "median", "max", "last"), - } + # Logits show up via `output_hidden_states` (`Global : head.logits` on the fw side and + # `Global : head.logits.grad` on the bw side once the loss has computed dL/dlogits). + # Each gets a dedicated column placed chronologically: end-of-fw and start-of-bw. + has_fw_logits = _named_row(sample, "head.logits") is not None + has_bw_logits = _named_row(sample, "head.logits.grad") is not None + fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) + bw_aggs = (("logits",) if has_bw_logits else ()) + ("first", "median", "max", "last") + aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs} kinds = ("fw", "bw") def _label(kind: str, agg: str) -> str: @@ -265,7 +263,12 @@ def _label(kind: str, agg: str) -> str: print(bottom) print("-" * len(bottom)) for name, rows in results.items(): - logits_value = _logits_row(rows)["rms_rel"] if _logits_row(rows) else float("nan") + logits_fw = _named_row(rows, "head.logits") + logits_bw = _named_row(rows, "head.logits.grad") + logits_value = { + "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), + "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), + } groups = [] for kind in kinds: values = [r["rms_rel"] for r in rows if r["kind"] == kind] @@ -280,7 +283,7 @@ def _label(kind: str, agg: str) -> str: elif agg == "last": value = values[-1] elif agg == "logits": - value = logits_value + value = logits_value[kind] elif agg == "max": value = max(intermediate) else: @@ -294,7 +297,9 @@ def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name # with "Global " and append a ": " suffix when reconstructing a - # tensor-parallel-global tensor. + # tensor-parallel-global tensor. Other entries (e.g. `Global : head.logits`, + # `Global : head.logits.grad`) come from the `_debug` / `output_hidden_states` path + # and are surfaced via dedicated logits columns in the summary. for kind in ("fw", "bw"): if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): return kind From 9ca17115835b53b26a79da21a7cb4b9f188fb6ba Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:35:48 -0400 Subject: [PATCH 20/41] Place logits after head in bw summary; widen format for sub-percent values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `.3f%` was rounding the bw-logits values down to 0.001%-0.000%, hiding real signal. Switch to `.4g%` so values across 5 orders of magnitude (0.0001% to ~20%) all render with meaningful precision; large values keep 4 significant figures, tiny ones spell out their leading non-zero digits or fall back to scientific. Bw column order is now first / logits / mid med / mid max / last so `logits` sits right after `head` (the first bw row) — semantically the gradient at logits is what the head's backward consumes before producing the gradient at its input. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 61d612895..5f23d206e 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -241,7 +241,7 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: has_fw_logits = _named_row(sample, "head.logits") is not None has_bw_logits = _named_row(sample, "head.logits.grad") is not None fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) - bw_aggs = (("logits",) if has_bw_logits else ()) + ("first", "median", "max", "last") + bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs} kinds = ("fw", "bw") @@ -288,7 +288,7 @@ def _label(kind: str, agg: str) -> str: value = max(intermediate) else: value = statistics.median(intermediate) - cells.append(f"{value * 100:.3f}%") + cells.append(f"{value * 100:.4g}%") groups.append(" ".join(f"{c:<{cell_width}}" for c in cells)) print(f"{name:<{name_width}}" + group_sep.join(groups)) From f2655f39223ea13555bf9b7aec4efd96dbcc5eac Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:40:11 -0400 Subject: [PATCH 21/41] =?UTF-8?q?Pick=20per-column=20decimals=20to=20guara?= =?UTF-8?q?ntee=20=E2=89=A52=20sig=20figs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Keep the prior `.3f%` default in the summary so most columns still show `0.000%` / `12.672%` style values, but compute a per-column decimal count based on the smallest non-zero value in that column — bumping up just enough that every cell carries at least two significant figures. Decimal count is uniform within a column. For the GSPO run, only the bw-logits column hits the threshold and gets bumped from 3 to 5 decimals, surfacing values like `0.00095%` that previously rounded to `0.001%` or worse. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 49 +++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 5f23d206e..6f6d86657 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -1,5 +1,6 @@ import json import logging +import math import pathlib import shutil import statistics @@ -262,6 +263,9 @@ def _label(kind: str, agg: str) -> str: print(top) print(bottom) print("-" * len(bottom)) + # Collect raw values first so we can pick a per-column decimal count: keep the previous + # .3f% default, but bump up just enough to give every cell in a column ≥ 2 sig figs. + raw: dict[str, dict[tuple[str, str], float | None]] = {} for name, rows in results.items(): logits_fw = _named_row(rows, "head.logits") logits_bw = _named_row(rows, "head.logits.grad") @@ -269,30 +273,55 @@ def _label(kind: str, agg: str) -> str: "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), } - groups = [] + cells: dict[tuple[str, str], float | None] = {} for kind in kinds: values = [r["rms_rel"] for r in rows if r["kind"] == kind] intermediate = values[1:-1] or values - cells = [] for agg in aggs_per_kind[kind]: if not values: - cells.append("n/a") + cells[(kind, agg)] = None continue if agg == "first": - value = values[0] + cells[(kind, agg)] = values[0] elif agg == "last": - value = values[-1] + cells[(kind, agg)] = values[-1] elif agg == "logits": - value = logits_value[kind] + cells[(kind, agg)] = logits_value[kind] elif agg == "max": - value = max(intermediate) + cells[(kind, agg)] = max(intermediate) else: - value = statistics.median(intermediate) - cells.append(f"{value * 100:.4g}%") - groups.append(" ".join(f"{c:<{cell_width}}" for c in cells)) + cells[(kind, agg)] = statistics.median(intermediate) + raw[name] = cells + + column_decimals: dict[tuple[str, str], int] = {} + for kind in kinds: + for agg in aggs_per_kind[kind]: + column_decimals[(kind, agg)] = _column_decimals( + cells[(kind, agg)] for cells in raw.values() if cells[(kind, agg)] is not None + ) + for name, cells in raw.items(): + groups = [] + for kind in kinds: + formatted = [] + for agg in aggs_per_kind[kind]: + value = cells[(kind, agg)] + if value is None: + formatted.append("n/a") + else: + formatted.append(f"{value * 100:.{column_decimals[(kind, agg)]}f}%") + groups.append(" ".join(f"{c:<{cell_width}}" for c in formatted)) print(f"{name:<{name_width}}" + group_sep.join(groups)) +def _column_decimals(values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3) -> int: + # Keep the previous default precision, but bump up so the smallest non-zero value + # carries at least `min_sig_figs` significant digits when formatted as percent. + smallest = min((abs(v) * 100 for v in values if v != 0), default=None) + if smallest is None or smallest >= 10 ** -(default - min_sig_figs + 1): + return default + return max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + + def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name From 7f8ef96cebce3b0e4ad7dfa76eb224ad8d1ebfae Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:48:10 -0400 Subject: [PATCH 22/41] Tighten summary table spacing Cell width drops from `max_label + 1` to `max_label`, inter-cell sep from two spaces to one, group sep from four spaces to three. About 18 chars narrower on the GSPO smoke run with no loss of alignment or readability. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 6f6d86657..ab35e3e2f 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -249,16 +249,17 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: def _label(kind: str, agg: str) -> str: return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] - name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + 1 - group_sep = " " + name_width = max((len(name) for name in results), default=7) + 1 + cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + cell_sep = " " + group_sep = " " group_widths = { - kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds + kind: len(cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds } print("\n=== Summary (Relative %; mid = excluding first/last) ===") top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) bottom = f"{'Variant':<{name_width}}" + group_sep.join( - " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds + cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds ) print(top) print(bottom) @@ -309,7 +310,7 @@ def _label(kind: str, agg: str) -> str: formatted.append("n/a") else: formatted.append(f"{value * 100:.{column_decimals[(kind, agg)]}f}%") - groups.append(" ".join(f"{c:<{cell_width}}" for c in formatted)) + groups.append(cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) print(f"{name:<{name_width}}" + group_sep.join(groups)) From 08b163745529db50fbeaa81c0aa51ad1162c082f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:07:18 -0400 Subject: [PATCH 23/41] Support HF Hub model ids in pretrained.path Lets `pretrained.path: org/model-id` resolve via huggingface_hub.snapshot_download when not a local directory, matching transformers' from_pretrained behavior. Local paths pass through unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/huggingface.py | 39 +++++++++++++++-------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index a4810dc1a..4c99798c5 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -100,6 +100,18 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name + @classmethod + def _resolve_path(cls, path: pathlib.Path) -> pathlib.Path: + """Resolve a local directory or HF Hub model id (e.g. ``meta-llama/Llama-3.2-1B``) to a + local snapshot directory. Local directories pass through unchanged; everything else is + materialized via :func:`huggingface_hub.snapshot_download` (cached on subsequent calls). + """ + if path.is_dir(): + return path + import huggingface_hub + + return pathlib.Path(huggingface_hub.snapshot_download(str(path))) + # Use custom config instead of relying on the transformers library @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: @@ -222,28 +234,29 @@ def _load_weights( import transformers Assert.eq(self.get_shard_names(config), ("weights",)) - if (config.path / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.SAFE_WEIGHTS_NAME} - elif (config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") + directory = self._resolve_path(config.path) + if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME} + elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } - elif (config.path / transformers.utils.WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.WEIGHTS_NAME} - elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") + elif (directory / transformers.utils.WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.WEIGHTS_NAME} + elif (directory / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } else: - raise FileNotFoundError(f"No compatible checkpoint found in {config.path}") + raise FileNotFoundError(f"No compatible checkpoint found in {directory}") for path in paths: logger.info(f"Loading from {path}") From 77eae22226dcb46d258d22822754979ca6d977f4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:07:25 -0400 Subject: [PATCH 24/41] Add example precision-evaluation configs Two ready-to-run configs for tools/evaluate_precision: smol.yaml sweeps precision-stability features (full_precision_gradients, full_precision_residual, fp32_lm_head) on SmolLM2-135M; smol_gspo.yaml repeats the sweep with the GSPO policy-gradient loss enabled. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 34 ++++++++++++++++++++ examples/evaluate_precision/smol_gspo.yaml | 37 ++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 examples/evaluate_precision/smol.yaml create mode 100644 examples/evaluate_precision/smol_gspo.yaml diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml new file mode 100644 index 000000000..2d443d3ba --- /dev/null +++ b/examples/evaluate_precision/smol.yaml @@ -0,0 +1,34 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M. +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +output_dir: /tmp/fast_llm_tests/evaluate_precision/features +sequence_length: 128 +num_samples: 512 +variants: + # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). + bf16: + model.distributed.compute_dtype: bfloat16 + # Turn OFF the default fp32 gradient accumulation — gradients accumulate in bf16. + bf16_no_fp32_gradients: + model.distributed.compute_dtype: bfloat16 + model.multi_stage.full_precision_gradients: false + # Turn ON full-precision residual stream. + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + # Turn ON fp32 LM head matmul (PR #526). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # Both stability features on (most precise bf16-compute configuration). + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml new file mode 100644 index 000000000..9e7188529 --- /dev/null +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -0,0 +1,37 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M +# with the GSPO policy-gradient loss (uses advantages and old log-probabilities). +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol_gspo.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +model: + base_model: + head: + losses: + gspo: + type: gspo +output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo +data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data +sequence_length: 128 +num_samples: 512 +variants: + bf16: + model.distributed.compute_dtype: bfloat16 + bf16_no_fp32_gradients: + model.distributed.compute_dtype: bfloat16 + model.multi_stage.full_precision_gradients: false + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true From efa95b1fd4ee609581ed3e698a68b99a8cb39a90 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:15:05 -0400 Subject: [PATCH 25/41] Drop bf16_no_fp32_gradients variant from example configs A single forward+backward pass with micro_batch_size=1 has no gradient accumulation, so toggling full_precision_gradients produces bit-identical results to the bf16 baseline. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 4 ---- examples/evaluate_precision/smol_gspo.yaml | 3 --- 2 files changed, 7 deletions(-) diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml index 2d443d3ba..8e052cbef 100644 --- a/examples/evaluate_precision/smol.yaml +++ b/examples/evaluate_precision/smol.yaml @@ -15,10 +15,6 @@ variants: # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). bf16: model.distributed.compute_dtype: bfloat16 - # Turn OFF the default fp32 gradient accumulation — gradients accumulate in bf16. - bf16_no_fp32_gradients: - model.distributed.compute_dtype: bfloat16 - model.multi_stage.full_precision_gradients: false # Turn ON full-precision residual stream. bf16_fp32_residual: model.distributed.compute_dtype: bfloat16 diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml index 9e7188529..c64276bdd 100644 --- a/examples/evaluate_precision/smol_gspo.yaml +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -22,9 +22,6 @@ num_samples: 512 variants: bf16: model.distributed.compute_dtype: bfloat16 - bf16_no_fp32_gradients: - model.distributed.compute_dtype: bfloat16 - model.multi_stage.full_precision_gradients: false bf16_fp32_residual: model.distributed.compute_dtype: bfloat16 model.base_model.embeddings.full_precision_residual: true From 46bc5b8ea74957e5e553dfd3e3397560b7138c07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:49:27 -0400 Subject: [PATCH 26/41] Add weight gradients to per-variant report tables Enables debug_all_param_gradients so every parameter's reduced gradient is captured in tensor_logs alongside the existing layer activations and input gradients. New rows are tagged with kind 'grad' and appear in the per-variant table but stay out of the fw/bw summary table. Also makes the per-variant table's Tensor column width fit the longest name (parameter gradients can be 40+ chars) and bumps the Relative column to adaptive precision (capped at 5 decimals) so legitimately tiny values stay legible. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 38 +++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index ab35e3e2f..e38ce347a 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -162,6 +162,7 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: tool_overrides = { ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, + ("model", "multi_stage", "debug_all_param_gradients"): _LOG_LEVEL, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], @@ -314,22 +315,32 @@ def _label(kind: str, agg: str) -> str: print(f"{name:<{name_width}}" + group_sep.join(groups)) -def _column_decimals(values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3) -> int: - # Keep the previous default precision, but bump up so the smallest non-zero value - # carries at least `min_sig_figs` significant digits when formatted as percent. +def _column_decimals( + values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3, max_decimals: int | None = None +) -> int: + # Keep the default precision, but bump up so the smallest non-zero value carries at least + # `min_sig_figs` significant digits when formatted as percent. `max_decimals` caps the + # bump so a single tiny noisy value doesn't widen the whole column. smallest = min((abs(v) * 100 for v in values if v != 0), default=None) if smallest is None or smallest >= 10 ** -(default - min_sig_figs + 1): - return default - return max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + result = default + else: + result = max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + return min(result, max_decimals) if max_decimals is not None else result def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name # with "Global " and append a ": " suffix when reconstructing a - # tensor-parallel-global tensor. Other entries (e.g. `Global : head.logits`, - # `Global : head.logits.grad`) come from the `_debug` / `output_hidden_states` path - # and are surfaced via dedicated logits columns in the summary. + # tensor-parallel-global tensor. Per-parameter gradient logs come from + # `Fsdp.log_shard(name="gradient", ...)` and are tagged "grad" so they appear + # in the per-variant table but stay out of the fw/bw summary aggregation. + # Other entries (e.g. `Global : head.logits`, `Global : head.logits.grad`) come + # from the `_debug` / `output_hidden_states` path and are surfaced via dedicated + # logits columns in the summary. + if "gradient:" in tensor_name: + return "grad" for kind in ("fw", "bw"): if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): return kind @@ -341,9 +352,16 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: if not rows: print("(no matching tensors)") return + name_fn = lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})" + name_width = max(len("Tensor"), max(len(name_fn(r)) for r in rows)) + # Adaptive precision for the relative column: bump decimals so small but real values + # (typical for weight gradients) stay legible, capped at 5 to bound column width. + relative_decimals = _column_decimals((r["rms_rel"] for r in rows), default=2, max_decimals=5) + relative_fn = lambda r: f"{r['rms_rel'] * 100:.{relative_decimals}f}%" + relative_width = max(len("Relative"), max(len(relative_fn(r)) for r in rows)) columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ - ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), - ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), + ("Tensor", name_width, name_fn), + ("Relative", relative_width, relative_fn), ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), From bef2f0db6a94b31501767c97d56e9c265c40e270 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:56:03 -0400 Subject: [PATCH 27/41] Separate fw/bw/grad rows in per-variant tables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Group rows in the per-variant tables by display group with blank lines between fw, bw, and grad. The reduce_gradients hook emits parameter gradients chronologically interleaved with the backward pass, which made the previous table hard to scan. Display grouping is independent of `kind` so the summary aggregation is unaffected — head.logits.grad just moves to the bw block visually. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index e38ce347a..f220df2e1 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -329,6 +329,17 @@ def _column_decimals( return min(result, max_decimals) if max_decimals is not None else result +def _display_group(row: dict[str, typing.Any]) -> str: + # Map each row to one of "fw"/"bw"/"grad" for the per-variant table, independent + # of `kind`: head.logits is a forward activation, head.logits.grad is a backward + # quantity, parameter gradients are their own group. + if row["kind"] == "grad": + return "grad" + if row["kind"] == "bw" or row["tensor_name"].endswith(".grad"): + return "bw" + return "fw" + + def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name @@ -369,8 +380,22 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) print("-" * len(header)) + # Display grouping (fw / bw / grad) separates the chronologically-interleaved + # backward and reduce_gradients hooks. Independent of `kind` so the summary + # aggregation isn't affected. + groups = ("fw", "bw", "grad") + grouped: dict[str, list[dict[str, typing.Any]]] = {g: [] for g in groups} for row in rows: - print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) + grouped[_display_group(row)].append(row) + first = True + for group in groups: + if not grouped[group]: + continue + if not first: + print() + first = False + for row in grouped[group]: + print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) if __name__ == "__main__": From 4fecad4860de22b300f0591b7484aadecef6da07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 15:15:59 -0400 Subject: [PATCH 28/41] Split summary into three tables (fw, bw, grad) Each pass gets its own self-contained Variant x columns table with labels picked from the actual first/last logged tensor. Weight gradients get a head/mid med/mid max/embeddings layout mirroring the bw structure; the grad table makes large norm_1 outliers (>200% relative) immediately visible at a glance. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 126 ++++++++++++++++-------------------- 1 file changed, 54 insertions(+), 72 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index f220df2e1..236871b35 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -209,8 +209,11 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _layer_name(tensor_name: str) -> str: # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; - # extract the layer to use as a meaningful column label. + # Fsdp.log_shard names weight gradients `Global gradient: `. prefix = tensor_name.split(":", 1)[0].strip().split() + if prefix == ["Global", "gradient"]: + param = tensor_name.split(":", 1)[1].strip() + return param.split(".")[0] if prefix and prefix[0] == "Global": prefix = prefix[1:] if prefix and prefix[-1] in ("fw", "bw"): @@ -223,51 +226,38 @@ def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - # Per-pass labels for `first`/`last` come from the actual layer name on the matching row. sample = next(iter(results.values())) - endpoint_labels: dict[tuple[str, str], str] = { - ("fw", "first"): "first", - ("fw", "last"): "last", - ("bw", "first"): "first", - ("bw", "last"): "last", - } - for kind in ("fw", "bw"): - group = [r for r in sample if r["kind"] == kind] - if group: - endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) - endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) - mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} - # Logits show up via `output_hidden_states` (`Global : head.logits` on the fw side and - # `Global : head.logits.grad` on the bw side once the loss has computed dL/dlogits). - # Each gets a dedicated column placed chronologically: end-of-fw and start-of-bw. has_fw_logits = _named_row(sample, "head.logits") is not None has_bw_logits = _named_row(sample, "head.logits.grad") is not None + # Each kind's aggregation columns are listed chronologically (left-to-right matches + # the order tensors are logged). Logits show up via `output_hidden_states` on the + # fw/bw boundary; weight gradients have no logits hook. fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") - aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs} - kinds = ("fw", "bw") + grad_aggs = ("first", "median", "max", "last") + aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} + for kind in ("fw", "bw", "grad"): + _print_summary_table(results, kind, aggs_per_kind[kind]) + - def _label(kind: str, agg: str) -> str: - return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] +def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: str, aggs: tuple[str, ...]) -> None: + sample = next(iter(results.values())) + group = [r for r in sample if r["kind"] == kind] + if not group: + return + endpoint_labels = { + "first": _layer_name(group[0]["tensor_name"]), + "last": _layer_name(group[-1]["tensor_name"]), + } + mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} + + def _label(agg: str) -> str: + return endpoint_labels[agg] if agg in endpoint_labels else mid_labels[agg] name_width = max((len(name) for name in results), default=7) + 1 - cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + cell_width = max(len(_label(a)) for a in aggs) cell_sep = " " - group_sep = " " - group_widths = { - kind: len(cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds - } - print("\n=== Summary (Relative %; mid = excluding first/last) ===") - top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) - bottom = f"{'Variant':<{name_width}}" + group_sep.join( - cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds - ) - print(top) - print(bottom) - print("-" * len(bottom)) - # Collect raw values first so we can pick a per-column decimal count: keep the previous - # .3f% default, but bump up just enough to give every cell in a column ≥ 2 sig figs. - raw: dict[str, dict[tuple[str, str], float | None]] = {} + raw: dict[str, dict[str, float | None]] = {} for name, rows in results.items(): logits_fw = _named_row(rows, "head.logits") logits_bw = _named_row(rows, "head.logits.grad") @@ -275,44 +265,36 @@ def _label(kind: str, agg: str) -> str: "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), } - cells: dict[tuple[str, str], float | None] = {} - for kind in kinds: - values = [r["rms_rel"] for r in rows if r["kind"] == kind] - intermediate = values[1:-1] or values - for agg in aggs_per_kind[kind]: - if not values: - cells[(kind, agg)] = None - continue - if agg == "first": - cells[(kind, agg)] = values[0] - elif agg == "last": - cells[(kind, agg)] = values[-1] - elif agg == "logits": - cells[(kind, agg)] = logits_value[kind] - elif agg == "max": - cells[(kind, agg)] = max(intermediate) - else: - cells[(kind, agg)] = statistics.median(intermediate) + values = [r["rms_rel"] for r in rows if r["kind"] == kind] + intermediate = values[1:-1] or values + cells: dict[str, float | None] = {} + for agg in aggs: + if not values: + cells[agg] = None + elif agg == "first": + cells[agg] = values[0] + elif agg == "last": + cells[agg] = values[-1] + elif agg == "logits": + cells[agg] = logits_value[kind] + elif agg == "max": + cells[agg] = max(intermediate) + else: + cells[agg] = statistics.median(intermediate) raw[name] = cells - column_decimals: dict[tuple[str, str], int] = {} - for kind in kinds: - for agg in aggs_per_kind[kind]: - column_decimals[(kind, agg)] = _column_decimals( - cells[(kind, agg)] for cells in raw.values() if cells[(kind, agg)] is not None - ) + column_decimals = { + agg: _column_decimals(cells[agg] for cells in raw.values() if cells[agg] is not None) for agg in aggs + } + print(f"\n=== Summary: {kind} (Relative %; mid = excluding first/last) ===") + header = f"{'Variant':<{name_width}}" + cell_sep.join(f"{_label(a):<{cell_width}}" for a in aggs) + print(header) + print("-" * len(header)) for name, cells in raw.items(): - groups = [] - for kind in kinds: - formatted = [] - for agg in aggs_per_kind[kind]: - value = cells[(kind, agg)] - if value is None: - formatted.append("n/a") - else: - formatted.append(f"{value * 100:.{column_decimals[(kind, agg)]}f}%") - groups.append(cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) - print(f"{name:<{name_width}}" + group_sep.join(groups)) + formatted = [ + f"{cells[agg] * 100:.{column_decimals[agg]}f}%" if cells[agg] is not None else "n/a" for agg in aggs + ] + print(f"{name:<{name_width}}" + cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) def _column_decimals( From 4f47dc045bab9def7f4e6a0286ea253c5c8262f0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 15:26:01 -0400 Subject: [PATCH 29/41] Split grad summary by parameter category Replace the chronological first/last columns in the grad table with named lookups (lm_head / embeddings) and split the intermediate aggregation by category: linear weights, norm weights, biases. The bias columns appear only when biases exist. lm_head shows n/a when the LM head weight is tied to the embedding (e.g. SmolLM2), since the combined gradient is recorded under the embedding parameter. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 80 ++++++++++++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 236871b35..abbc1e2da 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -225,21 +225,41 @@ def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing return next((r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == name), None) +_LM_HEAD_NAME = "head.output_weights" +_EMBEDDINGS_NAME = "embeddings.word_embeddings_weight" + + def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: sample = next(iter(results.values())) has_fw_logits = _named_row(sample, "head.logits") is not None has_bw_logits = _named_row(sample, "head.logits.grad") is not None + has_bias = any( + r["kind"] == "grad" and r["tensor_name"].split(":", 1)[-1].strip().endswith(".bias") for r in sample + ) # Each kind's aggregation columns are listed chronologically (left-to-right matches # the order tensors are logged). Logits show up via `output_hidden_states` on the # fw/bw boundary; weight gradients have no logits hook. fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") - grad_aggs = ("first", "median", "max", "last") + grad_aggs = ( + ("lm_head", "linear_med", "linear_max", "norm_med", "norm_max") + + (("bias_med", "bias_max") if has_bias else ()) + + ("embeddings",) + ) aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} for kind in ("fw", "bw", "grad"): _print_summary_table(results, kind, aggs_per_kind[kind]) +def _grad_category(tensor_name: str) -> str: + name = tensor_name.split(":", 1)[-1].strip() + if name.endswith(".bias"): + return "bias" + if ".norm_" in name or name.endswith(".norm.weight"): + return "norm" + return "linear" + + def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: str, aggs: tuple[str, ...]) -> None: sample = next(iter(results.values())) group = [r for r in sample if r["kind"] == kind] @@ -249,7 +269,19 @@ def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: "first": _layer_name(group[0]["tensor_name"]), "last": _layer_name(group[-1]["tensor_name"]), } - mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} + mid_labels = { + "median": "mid med", + "max": "mid max", + "logits": "logits", + "lm_head": "lm head", + "embeddings": "embeddings", + "linear_med": "linear med", + "linear_max": "linear max", + "norm_med": "norm med", + "norm_max": "norm max", + "bias_med": "bias med", + "bias_max": "bias max", + } def _label(agg: str) -> str: return endpoint_labels[agg] if agg in endpoint_labels else mid_labels[agg] @@ -265,28 +297,54 @@ def _label(agg: str) -> str: "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), } - values = [r["rms_rel"] for r in rows if r["kind"] == kind] + kind_rows = [r for r in rows if r["kind"] == kind] + values = [r["rms_rel"] for r in kind_rows] + if kind == "grad": + decoder_rows = [r for r in kind_rows if r["tensor_name"].split(":", 1)[-1].strip().startswith("decoder.")] + category_values: dict[str, list[float]] = {"linear": [], "norm": [], "bias": []} + for r in decoder_rows: + category_values[_grad_category(r["tensor_name"])].append(r["rms_rel"]) + lm_head_row = _named_row(kind_rows, _LM_HEAD_NAME) + embeddings_row = _named_row(kind_rows, _EMBEDDINGS_NAME) + else: + category_values = {} + lm_head_row = embeddings_row = None intermediate = values[1:-1] or values cells: dict[str, float | None] = {} for agg in aggs: - if not values: - cells[agg] = None - elif agg == "first": - cells[agg] = values[0] + if agg == "first": + cells[agg] = values[0] if values else None elif agg == "last": - cells[agg] = values[-1] + cells[agg] = values[-1] if values else None elif agg == "logits": cells[agg] = logits_value[kind] + elif agg == "lm_head": + cells[agg] = lm_head_row["rms_rel"] if lm_head_row else None + elif agg == "embeddings": + cells[agg] = embeddings_row["rms_rel"] if embeddings_row else None + elif "_" in agg and agg.split("_", 1)[0] in category_values: + cat, stat = agg.split("_", 1) + cat_values = category_values[cat] + if not cat_values: + cells[agg] = None + elif stat == "max": + cells[agg] = max(cat_values) + else: + cells[agg] = statistics.median(cat_values) elif agg == "max": - cells[agg] = max(intermediate) + cells[agg] = max(intermediate) if intermediate else None else: - cells[agg] = statistics.median(intermediate) + cells[agg] = statistics.median(intermediate) if intermediate else None raw[name] = cells column_decimals = { agg: _column_decimals(cells[agg] for cells in raw.values() if cells[agg] is not None) for agg in aggs } - print(f"\n=== Summary: {kind} (Relative %; mid = excluding first/last) ===") + if kind == "grad": + subtitle = " (Relative %)" + else: + subtitle = " (Relative %; mid = excluding first/last)" + print(f"\n=== Summary: {kind}{subtitle} ===") header = f"{'Variant':<{name_width}}" + cell_sep.join(f"{_label(a):<{cell_width}}" for a in aggs) print(header) print("-" * len(header)) From 5198c2551ebef9e6f53d273243037daf13c2763c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 15:50:18 -0400 Subject: [PATCH 30/41] Per-tensor sample-density overrides in TensorLogsConfig Add `sample_level_overrides: dict[str, int]` (regex pattern -> level) to `TensorLogsConfig`. `log_tensor` raises the effective level for any tensor whose logged name matches a pattern, so callers can collect more samples for specific tensors without changing the default. Useful for sparsely-non-zero tensors like embedding-weight gradients, where the default uniform stride misses every non-zero row. evaluate_precision: switch `num_samples` to actually drive the level (was only cropping the text log), bump default to 8192, default sequence length to 2048 in the example yamls, and add a 1M-sample override for `Global gradient: embeddings.*` to make embedding-grad errors measurable on small batches. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 3 +-- examples/evaluate_precision/smol_gspo.yaml | 3 +-- fast_llm/engine/config_utils/logging.py | 9 +++++++ fast_llm/logging.py | 9 +++++++ tools/evaluate_precision.py | 29 +++++++++++++++------- 5 files changed, 40 insertions(+), 13 deletions(-) diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml index 8e052cbef..cf0f8554b 100644 --- a/examples/evaluate_precision/smol.yaml +++ b/examples/evaluate_precision/smol.yaml @@ -9,8 +9,7 @@ pretrained: path: HuggingFaceTB/SmolLM2-135M format: llama output_dir: /tmp/fast_llm_tests/evaluate_precision/features -sequence_length: 128 -num_samples: 512 +sequence_length: 2048 variants: # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). bf16: diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml index c64276bdd..5e3545573 100644 --- a/examples/evaluate_precision/smol_gspo.yaml +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -17,8 +17,7 @@ model: type: gspo output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data -sequence_length: 128 -num_samples: 512 +sequence_length: 2048 variants: bf16: model.distributed.compute_dtype: bfloat16 diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 32deb4562..b82d4c847 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -76,6 +76,15 @@ class TensorLogsConfig(Config): valid=check_field(Assert.gt, 0), ) full_tensors: bool = Field(default=False, desc="Save and/or print entire tensors.") + sample_level_overrides: dict[str, int] = Field( + default_factory=dict, + desc="Per-tensor sample-density overrides (regex pattern -> level)." + " For tensors whose logged name matches a pattern, the effective `log_tensor` level is" + " raised to the matching override (samples = 2 ** (level - 3))." + " Useful for sparse tensors like embedding-weight gradients where the default sampling" + " stride misses most non-zero rows.", + hint=FieldHint.logging, + ) class TensorLogs: diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 2619883d6..6326e7e4b 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -131,6 +131,15 @@ def log_tensor[T]( ) -> T | None: if level < 1: return + # Per-tensor sample-density override: lets users boost the effective level for specific + # tensors (e.g. sparse embedding-weight gradients) via `TensorLogsConfig`. + overrides = TensorLogs.config.sample_level_overrides if TensorLogs.config else None + if overrides: + import re + + for pattern, override in overrides.items(): + if re.search(pattern, name): + level = max(level, override) tensor = tensor.detach() if tensor.ndim == 0: tensor = tensor[None] diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index abbc1e2da..02131cd6a 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -20,11 +20,14 @@ logger = logging.getLogger(__name__) -# Tensor-log verbosity level. 13 gives 2**(13-3)=1024 sampled values per tensor, -# matching the convention in the existing layer-comparison tests. -_LOG_LEVEL = 13 _REFERENCE_NAME = "reference" _MODEL_TYPE = "gpt" +# Embedding-weight gradients are row-sparse (only input-token rows non-zero), so a +# uniformly-spaced sample of vocab_size entries usually misses all of them. The pattern +# is applied via `TensorLogsConfig.sample_level_overrides` and picked up inside +# `log_tensor` (samples = 2 ** (level - 3) -> level 23 yields ~1M samples per tensor). +_SPARSE_GRAD_LEVEL = 23 +_SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL} @config_class() @@ -48,8 +51,10 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): hint=FieldHint.core, ) num_samples: int = Field( - default=1024, - desc="Number of sampled values stored per logged tensor.", + default=8192, + desc="Number of sampled values stored per logged tensor (rounded up to next power of 2)." + " Sparse tensors (e.g. embedding-weight gradients) get a higher level via" + " `TensorLogsConfig.sample_level_overrides`.", hint=FieldHint.feature, ) micro_batch_size: int = Field( @@ -150,9 +155,15 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: }, "run": { "experiment_dir": str((self.output_dir / name).resolve()), - "tensor_logs": {"save": True, "show": False, "max_elements": self.num_samples}, + "tensor_logs": { + "save": True, + "show": False, + "sample_level_overrides": _SPARSE_GRAD_OVERRIDES, + }, }, } + # Translate `num_samples` to a `log_tensor` level: 2**(level-3) = samples. + log_level = math.ceil(math.log2(max(self.num_samples, 1))) + 3 fp32_dtypes = { ("model", "distributed", "compute_dtype"): "float32", ("model", "distributed", "optimization_dtype"): "float32", @@ -160,9 +171,9 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: variant_updates = {tuple(key.split(".")): value for key, value in variant_overrides.items()} # Tool-required overrides win over variants — a variant must not silently disable tensor logging. tool_overrides = { - ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, - ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, - ("model", "multi_stage", "debug_all_param_gradients"): _LOG_LEVEL, + ("model", "multi_stage", "debug_layer_outputs"): log_level, + ("model", "multi_stage", "debug_layer_gradients"): log_level, + ("model", "multi_stage", "debug_all_param_gradients"): log_level, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], From 312343e7cdca3b56699f4494c25b0908bfe7d447 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 15:06:16 -0400 Subject: [PATCH 31/41] Chosen-logprob loss, per-variant grad-scale auto-calibration, fp16 variants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New `chosen_logprob` LM loss: logs `log_softmax(logits)[label]` per position with no gradient contribution. Tool auto-adds it and surfaces a dedicated summary with bias, correlation, slope, and residual-after-linear-fit. - `_compute_diff` reports bias_abs/rel, correlation, slope, residual_rms_abs/rel — the linear decomposition separates systematic shift/scale from per-position noise. - Per-variant auto-calibrated power-of-2 gradient scale: each variant runs a calibration pass at scale=1 to measure max unscaled gradient, then the real run picks the largest power-of-2 scale that fits within fp16 range (with a small safety factor for fused-kernel partial sums). `_compare` unscales per variant. - Tool: backend-override mechanism (`_torch_backend.*`) and `_torch_matmul_precision` variant keys for diagnostic variants. New variants: `bf16_in_fp32_out` (probes whether `fp32_lm_head`'s gain is from output dtype vs matmul precision), `bf16_reduced_reduction` (probes the split-K reduction path), and a full fp16 sweep mirroring the bf16 variants. - Fix: `data.micro_batch_size` in Fast-LLM is the per-sample sequence length, not the batch dim. Tool was passing 1 → every prior run was on 1-token inputs. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 30 +++ examples/evaluate_precision/smol_gspo.yaml | 19 ++ .../config_utils/compare_tensor_logs.py | 26 +- .../language_model/loss/chosen_logprob.py | 41 ++++ fast_llm/layers/language_model/loss/config.py | 25 ++ tools/evaluate_precision.py | 222 ++++++++++++++++-- 6 files changed, 342 insertions(+), 21 deletions(-) create mode 100644 fast_llm/layers/language_model/loss/chosen_logprob.py diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml index cf0f8554b..cc17c19e0 100644 --- a/examples/evaluate_precision/smol.yaml +++ b/examples/evaluate_precision/smol.yaml @@ -27,3 +27,33 @@ variants: model.distributed.compute_dtype: bfloat16 model.base_model.embeddings.full_precision_residual: true model.base_model.head.fp32_lm_head: true + # Diagnostic: enable bf16 reduced-precision reductions in cuBLAS GEMMs. Tests whether the + # within-engine bf16-vs-fp32 gap is sensitive to the partial-sum reduction precision (the + # MMA accumulator is fp32 by hardware on H100/A100; this flag affects split-K reductions). + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + # Diagnostic: simulate a "bf16 inputs, fp32 output" lm-head matmul kernel. fp32_lm_head=True + # upcasts inputs+weights to fp32, then matmul_precision='medium' runs the matmul through + # bf16 Tensor Cores anyway, then logits stay fp32. Tests whether fp32_lm_head's gain comes + # from input precision or from skipping the bf16 output cast. + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + # fp16 sweep: probes whether the precision-vs-noise picture (rms noise ~0.1 nats per token + # for bf16) shrinks ~8× for fp16 (10 mantissa bits vs 7), as the literature's "switch to + # fp16" recommendation implies. Default dynamic grad-scaler (initial 2^16) is uniform + # across variants, so relative comparisons stay meaningful. + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml index 5e3545573..b0e8e319d 100644 --- a/examples/evaluate_precision/smol_gspo.yaml +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -31,3 +31,22 @@ variants: model.distributed.compute_dtype: bfloat16 model.base_model.embeddings.full_precision_residual: true model.base_model.head.fp32_lm_head: true + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/fast_llm/engine/config_utils/compare_tensor_logs.py b/fast_llm/engine/config_utils/compare_tensor_logs.py index 080510036..dbad78a25 100644 --- a/fast_llm/engine/config_utils/compare_tensor_logs.py +++ b/fast_llm/engine/config_utils/compare_tensor_logs.py @@ -100,8 +100,24 @@ def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict samples_test = samples_test / sub_config.scale scale_unreg = (samples_ref**2).mean() ** 0.5 rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 - rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 - max_diff = (samples_ref - samples_test).abs().max() + diff = samples_test - samples_ref + rms = (diff**2).mean() ** 0.5 + max_diff = diff.abs().max() + bias = diff.mean() + # Linear-regression decomposition: `test ≈ slope * ref + intercept + residual`. + # Useful for separating systematic distortion (slope ≠ 1) from per-position decorrelated + # noise (residual). For RL importance ratios, slope ≠ 1 indicates likely-token-dependent + # bias which is more dangerous than a uniform shift. + centered_test = samples_test - samples_test.mean() + centered_ref = samples_ref - samples_ref.mean() + var_ref = (centered_ref**2).mean() + var_test = (centered_test**2).mean() + cov = (centered_test * centered_ref).mean() + denom = (var_test * var_ref) ** 0.5 + correlation = (cov / denom).item() if denom > 0 else float("nan") + slope = (cov / var_ref).item() if var_ref > 0 else float("nan") + residual_var = (var_test - cov**2 / var_ref).clamp(min=0.0) if var_ref > 0 else var_test + residual_rms = residual_var**0.5 return { "rms_abs": rms.item(), "rms_rel": (rms / rms_scale).item(), @@ -109,6 +125,12 @@ def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict "max_rel": (max_diff / rms_scale).item(), "ref_scale": scale_unreg.item(), "ref_scale_regularized": rms_scale.item(), + "bias_abs": bias.item(), + "bias_rel": (bias / rms_scale).item(), + "correlation": correlation, + "slope": slope, + "residual_rms_abs": residual_rms.item(), + "residual_rms_rel": (residual_rms / rms_scale).item(), } def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name): diff --git a/fast_llm/layers/language_model/loss/chosen_logprob.py b/fast_llm/layers/language_model/loss/chosen_logprob.py new file mode 100644 index 000000000..cb99e7c17 --- /dev/null +++ b/fast_llm/layers/language_model/loss/chosen_logprob.py @@ -0,0 +1,41 @@ +import math +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelChosenLogprobLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.logging import log_tensor + + +class LanguageModelChosenLogprobLoss[ConfigType: LanguageModelChosenLogprobLossConfig](LanguageModelLoss[ConfigType]): + """Logs log π(label) per position via the tensor-log pipeline; contributes nothing to gradients.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Don't surface a "chosen_logprob: 0" line in the training metrics. + self._do_register_loss = False + + def _forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + if self._vocab_parallel: + raise NotImplementedError("chosen_logprob loss does not support vocab parallel") + labels = self._get_labels(kwargs, split_index).reshape(-1).long() + with torch.no_grad(): + log_probs = torch.log_softmax(logits.float() * self._logits_scale_factor, dim=-1) + # Mask out-of-range labels (e.g. -100 for prompt tokens in RL data) before gather to + # avoid CUDA assert. Fast-LLM convention: any label < 0 is masked. + valid = labels >= 0 + safe_labels = labels.clamp(min=0) + chosen_logprob = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) + chosen_logprob = chosen_logprob[valid] + # Capture the full tensor: bias is the mean over all positions, not a sampled subset. + level = math.ceil(math.log2(max(chosen_logprob.numel(), 1))) + 3 + log_tensor(f"Global : {self._name}", chosen_logprob, level=level) + return torch.zeros((), dtype=logits.dtype, device=logits.device), grad_logits diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 9a220aacf..aa05fbb9a 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -9,6 +9,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss from fast_llm.layers.language_model.loss.entropy_loss import ( LanguageModelDistillationLoss, @@ -186,6 +187,30 @@ def get_reference_models(self) -> set[str]: return {self.reference_model} +@config_class(dynamic_type={LanguageModelLossConfig: "chosen_logprob"}) +class LanguageModelChosenLogprobLossConfig(LanguageModelLossConfig): + """No-gradient diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + + The chosen-token log-prob is the scalar that policy-gradient importance ratios depend on, + so its precision drift is a more direct signal than bulk-logit RMS. + """ + + _abstract: typing.ClassVar[bool] = False + + weight: float = Field( + default=0.0, + hint=FieldHint.derived, + desc="Forced to 0: this loss has no gradient contribution.", + valid=check_field(Assert.eq, 0.0), + ) + + @property + def loss_class(self) -> "type[LanguageModelChosenLogprobLoss]": + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss + + return LanguageModelChosenLogprobLoss + + @config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) class LanguageModelZLossConfig(LanguageModelLossConfig): """Z-loss regularization to prevent overconfidence.""" diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 02131cd6a..9da8904a1 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -28,6 +28,23 @@ # `log_tensor` (samples = 2 ** (level - 3) -> level 23 yields ~1M samples per tensor). _SPARSE_GRAD_LEVEL = 23 _SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL} +_CHOSEN_LOGPROB_NAME = "chosen_logprob" +# Auto-calibration of the constant gradient scaler. Each variant runs a calibration pass at +# `scale=1` (no overflow risk), then the actual run uses the largest power-of-2 scale that +# keeps logged gradient magnitudes (and a small safety factor for hidden in-kernel +# intermediates like norm partial sums) within fp16's representable range. Per-variant +# unscaling at compare time lets different variants pick different scales without polluting +# the relative metrics. +_HIDDEN_INTERMEDIATE_HEADROOM = 4.0 # safety factor for fused-kernel partial sums we don't log +_CALIBRATION_SUBDIR_PREFIX = ".calibration_" +# Variant-override keys starting with this prefix are interpreted as `torch.backends.` and +# applied before each run. Used for diagnostics (e.g. enabling bf16 reduced-precision reductions); +# entries are listed in `_TORCH_BACKEND_DEFAULTS` and reset to their defaults before applying. +_TORCH_BACKEND_PREFIX = "_torch_backend." +_TORCH_BACKEND_DEFAULTS = { + "cuda.matmul.allow_bf16_reduced_precision_reduction": False, +} +_TORCH_MATMUL_PRECISION_KEY = "_torch_matmul_precision" @config_class() @@ -57,14 +74,10 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): " `TensorLogsConfig.sample_level_overrides`.", hint=FieldHint.feature, ) - micro_batch_size: int = Field( - default=1, - desc="Micro-batch size for the single forward+backward pass.", - hint=FieldHint.feature, - ) sequence_length: int = Field( default=2048, - desc="Sequence length (maximum document length) for the random input.", + desc="Sequence length per micro-batch sample. Drives both `data.micro_batch_size` (the" + " per-sample token count, despite the name) and `data.maximum_document_length`.", hint=FieldHint.feature, ) data_path: pathlib.Path | None = Field( @@ -88,20 +101,50 @@ def run(self) -> None: self._prepare_data() runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) + scales: dict[str, float] = {} for name, variant_overrides in runs.items(): - self._run_one(name, variant_overrides) + scales[name] = self._calibrate_and_run(name, variant_overrides) ref_artifacts = self._artifact_path(_REFERENCE_NAME) - results = {name: self._compare(ref_artifacts, self._artifact_path(name)) for name in self.variants} + results = { + name: self._compare(ref_artifacts, self._artifact_path(name), scales[_REFERENCE_NAME], scales[name]) + for name in self.variants + } report_path = self.output_dir / "precision_report.json" - report_path.write_text(json.dumps(results, indent=2)) + report_path.write_text(json.dumps({"scales": scales, "variants": results}, indent=2)) logger.info(f"Wrote report to {report_path}") + logger.info(f"Per-variant gradient scales: {scales}") for name, rows in results.items(): _print_table(name, rows) _print_summary(results) + def _calibrate_and_run(self, name: str, variant_overrides: dict[str, typing.Any]) -> float: + """Pick a power-of-2 gradient scale for this variant via a calibration pass, then run with it. + + Calibration runs with `constant=1.0` so no overflow is possible; scanning logged gradients + then gives us `max_unscaled`. The largest safe power of 2 keeps `scale * max_unscaled` below + `fp16_max / hidden_intermediate_budget`, where the budget reserves headroom for partial sums + inside fused kernels (e.g. norm-weight grads sum over the sequence dimension). + """ + import torch + + cal_dir = self.output_dir / f"{_CALIBRATION_SUBDIR_PREFIX}{name}" + self._run_one(name, variant_overrides, constant_scale=1.0, experiment_dir=cal_dir) + max_unscaled = _scan_max_grad(cal_dir / "runs" / "0" / "artifacts") + shutil.rmtree(cal_dir) + if max_unscaled <= 0.0: + scale = 1.0 + logger.warning(f"[{name}] calibration found no nonzero gradient — falling back to scale=1.0") + else: + fp16_max = torch.finfo(torch.float16).max + optimal_unrounded = fp16_max / max_unscaled / _HIDDEN_INTERMEDIATE_HEADROOM + scale = float(2 ** max(0, math.floor(math.log2(optimal_unrounded)))) + logger.info(f"[{name}] calibration: max_unscaled={max_unscaled:.4e} -> gradient_scaler.constant={scale:g}") + self._run_one(name, variant_overrides, constant_scale=scale) + return scale + def _prepare_data(self) -> None: if self.data_path is None: return @@ -122,15 +165,28 @@ def _prepare_data(self) -> None: def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" - def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: + def _run_one( + self, + name: str, + variant_overrides: dict[str, typing.Any], + *, + constant_scale: float | None = None, + experiment_dir: pathlib.Path | None = None, + ) -> None: # The trainer's Run picks the next `runs/` subdir based on what already exists; wipe # any prior contents so each invocation lands in `runs/0` and stale artifacts can't be # read by `_artifact_path` below. - experiment_dir = self.output_dir / name + if experiment_dir is None: + experiment_dir = self.output_dir / name if experiment_dir.exists(): shutil.rmtree(experiment_dir) # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. # Forced fp32 on the reference baseline lives in here too so a variant can override it. + optimizer_config: dict[str, typing.Any] = { + "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, + } + if constant_scale is not None: + optimizer_config["gradient_scaler"] = {"constant": float(constant_scale)} base_dict: dict[str, typing.Any] = { "pretrained": self.pretrained.to_dict(), "model": self.model.to_dict(), @@ -139,9 +195,7 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: "num_workers": 0, "logs": {"interval": 1}, }, - "optimizer": { - "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, - }, + "optimizer": optimizer_config, "data": { "datasets": { "training": ( @@ -150,11 +204,13 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: else {"type": "random"} ) }, - "micro_batch_size": self.micro_batch_size, + # Despite the name, Fast-LLM's `data.micro_batch_size` is the per-sample sequence + # length, not the batch dimension. Default 2048 → 2048-token sample. + "micro_batch_size": self.sequence_length, "maximum_document_length": self.sequence_length, }, "run": { - "experiment_dir": str((self.output_dir / name).resolve()), + "experiment_dir": str(experiment_dir.resolve()), "tensor_logs": { "save": True, "show": False, @@ -168,16 +224,36 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: ("model", "distributed", "compute_dtype"): "float32", ("model", "distributed", "optimization_dtype"): "float32", } - variant_updates = {tuple(key.split(".")): value for key, value in variant_overrides.items()} + # Split off torch-backend overrides before passing the rest to Fast-LLM's config system. + backend_overrides = { + key[len(_TORCH_BACKEND_PREFIX) :]: value + for key, value in variant_overrides.items() + if key.startswith(_TORCH_BACKEND_PREFIX) + } + _apply_torch_backend_overrides(backend_overrides) + matmul_precision = variant_overrides.get(_TORCH_MATMUL_PRECISION_KEY, "highest") + _apply_torch_matmul_precision(matmul_precision) + variant_updates = { + tuple(key.split(".")): value + for key, value in variant_overrides.items() + if not key.startswith(_TORCH_BACKEND_PREFIX) and key != _TORCH_MATMUL_PRECISION_KEY + } # Tool-required overrides win over variants — a variant must not silently disable tensor logging. - tool_overrides = { + tool_overrides: dict[tuple[str, ...], typing.Any] = { ("model", "multi_stage", "debug_layer_outputs"): log_level, ("model", "multi_stage", "debug_layer_gradients"): log_level, ("model", "multi_stage", "debug_all_param_gradients"): log_level, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], + # Diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + # Contributes no gradient (weight=0); the comparison code picks it up by name. + ("model", "base_model", "head", "losses", _CHOSEN_LOGPROB_NAME): {"type": "chosen_logprob"}, } + # When the user hasn't configured any loss, the head defaults to cross-entropy. Adding a + # loss explicitly suppresses that default, so re-add it so gradients still flow. + if not (self.model.base_model.head.losses or {}): + tool_overrides[("model", "base_model", "head", "losses", "cross_entropy")] = {"type": "label"} logger.info(f"=== Running {name!r} ===") if variant_overrides: logger.info(f"Variant overrides: {variant_overrides}") @@ -186,13 +262,23 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: trainer_config.configure_logging() trainer_config._get_runnable()() - def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: + def _compare( + self, + ref_path: pathlib.Path, + test_path: pathlib.Path, + ref_scale: float, + test_scale: float, + ) -> list[dict[str, typing.Any]]: compare_config = CompareConfig() errors: list[str] = [] ref_logs = compare_config._extract_tensor_logs(ref_path, errors) test_logs = compare_config._extract_tensor_logs(test_path, errors) for error in errors: logger.warning(error) + # Each variant's gradient logs are scaled by its own `constant` factor (auto-calibrated). + # Undo per-variant scaling so the relative comparison reflects unscaled gradient diffs. + _unscale_gradients_in_place(ref_logs, ref_scale) + _unscale_gradients_in_place(test_logs, test_scale) rows: list[dict[str, typing.Any]] = [] for step_name in sorted(ref_logs): if step_name not in test_logs: @@ -218,6 +304,66 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows +def _is_gradient_like(tensor_name: str) -> bool: + # Anything affected by the loss-scaling multiplier: parameter gradients from `Fsdp.log_shard`, + # backward activations from layer hooks, and explicit `.grad` debug entries (e.g. logits.grad). + return ("gradient:" in tensor_name) or (" bw" in tensor_name) or (".grad" in tensor_name) + + +def _scan_max_grad(artifact_path: pathlib.Path) -> float: + max_abs = 0.0 + compare_config = CompareConfig() + errors: list[str] = [] + logs = compare_config._extract_tensor_logs(artifact_path, errors) + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if not _is_gradient_like(tensor_name): + continue + # Saved stats include min/max; fall back to samples if absent. + if "max" in entry and "min" in entry: + value = max(abs(float(entry["max"])), abs(float(entry["min"]))) + else: + value = float(entry["samples"].abs().max().item()) + if math.isfinite(value) and value > max_abs: + max_abs = value + return max_abs + + +def _unscale_gradients_in_place(logs: dict, scale: float) -> None: + if scale == 1.0: + return + inv = 1.0 / scale + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if not _is_gradient_like(tensor_name): + continue + entry["samples"] = entry["samples"].float() * inv + for key in ("min", "max", "mu", "std"): + if key in entry and entry[key] is not None: + entry[key] = float(entry[key]) * inv + + +def _apply_torch_backend_overrides(overrides: dict[str, typing.Any]) -> None: + import torch + + unknown = set(overrides) - set(_TORCH_BACKEND_DEFAULTS) + if unknown: + logger.warning(f"Unknown torch backend overrides (ignored): {sorted(unknown)}") + for path, default in _TORCH_BACKEND_DEFAULTS.items(): + value = overrides.get(path, default) + obj: typing.Any = torch.backends + parts = path.split(".") + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + +def _apply_torch_matmul_precision(precision: str) -> None: + import torch + + torch.set_float32_matmul_precision(precision) + + def _layer_name(tensor_name: str) -> str: # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; # Fsdp.log_shard names weight gradients `Global gradient: `. @@ -260,6 +406,40 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} for kind in ("fw", "bw", "grad"): _print_summary_table(results, kind, aggs_per_kind[kind]) + if _named_row(sample, _CHOSEN_LOGPROB_NAME) is not None: + _print_chosen_logprob_summary(results) + + +def _print_chosen_logprob_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + rows_by_variant = {name: _named_row(rows, _CHOSEN_LOGPROB_NAME) for name, rows in results.items()} + # log π(label) is the scalar that policy-gradient importance ratios depend on. Bias persists + # under per-document averaging where RMS shrinks ~1/√T, so for RL stability it's the more + # informative signal — surface it alongside RMS, slope and residual. + rms_rel_decimals = _column_decimals((r["rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + bias_rel_decimals = _column_decimals((r["bias_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + resid_rel_decimals = _column_decimals( + (r["residual_rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5 + ) + name_width = max((len(name) for name in results), default=7) + 1 + cols = [ + ("RMS rel", lambda r: f"{r['rms_rel'] * 100:.{rms_rel_decimals}f}%"), + ("Bias rel", lambda r: f"{r['bias_rel'] * 100:+.{bias_rel_decimals}f}%"), + ("Resid rel", lambda r: f"{r['residual_rms_rel'] * 100:.{resid_rel_decimals}f}%"), + ("Corr", lambda r: f"{r['correlation']:.5f}"), + ("Slope", lambda r: f"{r['slope']:+.5f}"), + ("Max abs", lambda r: f"{r['max_abs']:.4g}"), + ("Scale", lambda r: f"{r['ref_scale']:.4g}"), + ] + widths = [max(len(label), max(len(fn(r)) for r in rows_by_variant.values())) for label, fn in cols] + print(f"\n=== Summary: chosen_logprob (per-token) ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, row in rows_by_variant.items(): + cells = [fn(row) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) def _grad_category(tensor_name: str) -> str: @@ -420,10 +600,14 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: # (typical for weight gradients) stay legible, capped at 5 to bound column width. relative_decimals = _column_decimals((r["rms_rel"] for r in rows), default=2, max_decimals=5) relative_fn = lambda r: f"{r['rms_rel'] * 100:.{relative_decimals}f}%" + bias_decimals = _column_decimals((r["bias_rel"] for r in rows), default=2, max_decimals=5) + bias_fn = lambda r: f"{r['bias_rel'] * 100:+.{bias_decimals}f}%" relative_width = max(len("Relative"), max(len(relative_fn(r)) for r in rows)) + bias_width = max(len("Bias"), max(len(bias_fn(r)) for r in rows)) columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ ("Tensor", name_width, name_fn), ("Relative", relative_width, relative_fn), + ("Bias", bias_width, bias_fn), ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), From 497c76cd07c75259a213189f10e257eea85c025c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 1 Jun 2026 12:37:09 -0400 Subject: [PATCH 32/41] Lean fixed-input runner + DeepSpeed-side precision comparison MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the Trainer + data-loading path in tools/evaluate_precision.py with a lean forward+backward runner (InferenceRunner-style: model + ScheduleRunner + lr-0 optimizer, training-phase schedule) fed a fixed, already-preprocessed input. This lets the model see an exact token tensor (the data pipeline would re-randomize the model input via shuffle/packing) and drops the training/data-loading infrastructure the tool doesn't need — which also fixes the GPU-memory accumulation that OOM'd on larger models (each run's model+optimizer is now freed). The input is built once (configurable input_text_file -> tokenized, or uniform-random) and saved to output_dir/input_ids.pt so the DeepSpeed-side tool can consume byte-identical model input. Add tools/evaluate_precision_deepspeed.py: the HF-transformers + DeepSpeed counterpart, mirroring PipelineRL's proven fp32-lm-head and log-pi computation, reporting the same chosen-logprob and categorized-gradient metrics so Fast-LLM's bf16 precision pattern can be benchmarked against DeepSpeed's. fp16 gradients use loss scaling to avoid underflow. Add examples/evaluate_precision/qwen.yaml and sample_text.txt for the Qwen2.5-0.5B comparison. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/evaluate_precision/qwen.yaml | 25 ++ examples/evaluate_precision/sample_text.txt | 23 ++ tools/evaluate_precision.py | 137 +++++++--- tools/evaluate_precision_deepspeed.py | 281 ++++++++++++++++++++ 4 files changed, 431 insertions(+), 35 deletions(-) create mode 100644 examples/evaluate_precision/qwen.yaml create mode 100644 examples/evaluate_precision/sample_text.txt create mode 100644 tools/evaluate_precision_deepspeed.py diff --git a/examples/evaluate_precision/qwen.yaml b/examples/evaluate_precision/qwen.yaml new file mode 100644 index 000000000..9a28e270d --- /dev/null +++ b/examples/evaluate_precision/qwen.yaml @@ -0,0 +1,25 @@ +# Precision-evaluation config on Qwen2.5-0.5B — the model used for the Fast-LLM vs DeepSpeed +# precision-pattern comparison (DeepSpeed side: tools/evaluate_precision_deepspeed.py). +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen.yaml +pretrained: + path: Qwen/Qwen2.5-0.5B + format: qwen2 +output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_features +sequence_length: 2048 +variants: + # Maps to the DeepSpeed harness's `bf16_head_bf16` (compute bf16, lm head in compute dtype). + bf16: + model.distributed.compute_dtype: bfloat16 + # Maps to the DeepSpeed harness's `bf16` (compute bf16, fp32 lm head — the stack default). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # Maps to the DeepSpeed harness's `fp16_head_fp16`. + fp16: + model.distributed.compute_dtype: float16 + # Maps to the DeepSpeed harness's `fp16`. + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/sample_text.txt b/examples/evaluate_precision/sample_text.txt new file mode 100644 index 000000000..8b207173d --- /dev/null +++ b/examples/evaluate_precision/sample_text.txt @@ -0,0 +1,23 @@ +The history of computing is often told as a story of ever-smaller and ever-faster machines, but the more interesting thread is the slow accumulation of good abstractions. Early programmers spoke directly to the hardware, toggling switches and rewiring panels, and every problem had to be solved in the vocabulary of the machine in front of them. The arrival of assembly language, and then of compiled languages, did not make the computers any faster; it made the programmers faster, because it let them think in terms closer to the problem and further from the circuitry. Each new layer hid a mess of detail beneath a clean interface, and each clean interface freed the people above it to build something larger than the layer below could have imagined. + +Numerical computation followed the same pattern, though its abstractions were mathematical rather than mechanical. The first scientific programs tracked every digit by hand, and a single rounding decision could quietly ruin a long calculation. Floating point arithmetic was a hard-won compromise: it traded a little accuracy for an enormous gain in range and convenience, and it came with rules subtle enough that careful engineers spent entire careers studying them. The promise was never that the answers would be exact, only that the errors would be small and, more importantly, predictable. A method whose errors stay bounded and behave smoothly is far more useful than one that is occasionally perfect and occasionally catastrophic, because predictability is what lets you reason about a system you cannot fully observe. + +This distinction between bounded error and occasional disaster runs through the whole of engineering. A bridge is not designed to bear exactly the load it will encounter; it is designed with margins, so that the inevitable surprises fall inside a region the designer has already considered. Software that processes real data is no different. The inputs will be messier than the specification promised, the edge cases will arrive in combinations nobody enumerated, and the only durable defense is to build systems whose failure modes are gentle. A program that degrades gracefully under unexpected input is worth more than one that is flawless on the cases its author happened to imagine, because the world is under no obligation to supply only imaginable cases. + +Modern machine learning lives squarely inside this tradition, even when its practitioners do not describe it that way. Training a large model means multiplying enormous matrices billions of times, and the precision of each multiplication is a design choice rather than a fixed fact of nature. Lower precision means smaller numbers to move and faster hardware to move them, but it also means coarser rounding, and the central question is always whether that rounding stays in the harmless regime or crosses into the dangerous one. The answer depends on the model, the data, and the particular sequence of operations involved, which is exactly why it has to be measured rather than assumed. Intuition about numerical behavior is notoriously unreliable at scale, where quantities interact in ways that small examples never reveal. + +Consider what happens to a single number as it flows through a deep network. It begins as an input, is scaled and shifted and combined with thousands of its neighbors, passes through a nonlinearity, and emerges as part of the input to the next layer, where the whole process repeats. By the time it reaches the final layer it has been transformed dozens of times, and any error introduced early has had dozens of opportunities to grow or shrink. Sometimes these errors cancel, averaging out across many independent contributions; sometimes they reinforce, when the same systematic bias is applied at every step. The difference between these two fates is the difference between a model that trains stably and one that diverges for reasons its authors struggle to explain. + +The output layer deserves special attention, because it is where the model finally commits to a prediction. Up to that point the internal representations are abstract and somewhat forgiving; small perturbations shift them a little without changing their meaning. But the final projection turns those representations into concrete scores over a large vocabulary, and those scores are then exponentiated and normalized into probabilities. Exponentiation is unforgiving of additive error: a small shift in a score becomes a multiplicative change in a probability, and a small change in a probability can flip a decision. This is why the precision of the last step is often discussed out of proportion to its share of the total computation. It is not that the last matrix multiply is expensive; it is that it sits at the most sensitive point in the pipeline. + +Yet sensitivity at a single point does not automatically translate into importance for the whole. If the representation arriving at that point already carries substantial error from everything upstream, then cleaning up only the final step yields little, because the dominant error was introduced earlier and is simply passed through. The benefit of high precision at the output is largest exactly when the rest of the pipeline is already clean, and smallest when the upstream is noisy. This is a general principle of error analysis that beginners frequently miss: the value of fixing one stage depends entirely on whether that stage is the bottleneck, and the bottleneck is rarely where attention is first drawn. + +There is a further subtlety, which is that the magnitude of the quantities involved changes how much a fixed rounding error matters in relative terms. When a model is confident, the score it assigns to the chosen outcome is close to the maximum, the corresponding log probability is close to zero, and a small absolute error in that log probability is a large fraction of its tiny value. When a model is uncertain, spreading its belief across many outcomes, the same log probability is a large negative number, and the identical absolute error is a negligible fraction of it. The relative importance of a rounding step therefore depends not only on where it sits in the pipeline but on the regime the model is operating in, which is set by the data it happens to be processing at that moment. + +This is why measurements that look contradictory are often perfectly consistent once the regime is accounted for. A change that appears to make no difference on one dataset can make a clear difference on another, not because the underlying arithmetic changed, but because the quantities being rounded shifted from one regime to the other. An honest investigation reports both results and the condition that distinguishes them, rather than picking whichever supports a tidy story. The condition is the finding; the individual numbers are only evidence for it. + +Reinforcement learning from human feedback adds yet another layer to this picture, because it compares the behavior of two systems rather than examining one in isolation. A model generates text under one implementation and is then evaluated under another, and the learning signal depends on the ratio between the probabilities the two implementations assign to the same tokens. If the two implementations agree, the ratio is near one and the signal is clean; if they disagree systematically, the ratio carries a bias that no amount of careful optimization can remove, because it is baked into the comparison itself. The danger here is not random noise, which averages away over many samples, but systematic disagreement, which does not. Two correct-looking systems can still disagree in a way that quietly corrupts everything built on top of their comparison. + +The practical lesson is that matching matters more than absolute accuracy in this setting. It is better for two systems to be wrong in the same way than for one to be right and the other wrong, because a shared error cancels in the ratio while an unshared one does not. This inverts the usual intuition, which prizes accuracy above all. It explains why engineers sometimes deliberately make a fast system reproduce the quirks of a slow one, rather than improving it, and why a change that improves a system in isolation can hurt the larger pipeline it lives in if it breaks an agreement that other parts relied upon. Consistency is a feature, even when it is consistency in imperfection. + +All of this argues for a particular discipline: measure the thing you actually care about, under the conditions it will actually face, and report the conditions alongside the numbers. Good measurement, like a good abstraction, is what lets us trust the layers we cannot see. It does not eliminate uncertainty, but it bounds it, and a bounded uncertainty is something an engineer can build on. The goal is never to pretend the errors are gone. The goal is to know how large they are, where they come from, and whether they stay in the gentle regime or threaten to cross into the steep one where small causes produce large and unwelcome effects. diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 9da8904a1..48e67500c 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -29,6 +29,8 @@ _SPARSE_GRAD_LEVEL = 23 _SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL} _CHOSEN_LOGPROB_NAME = "chosen_logprob" +# Seed for the random-token fixed input when no input text file is given. +_INPUT_SEED = 0 # Auto-calibration of the constant gradient scaler. Each variant runs a calibration pass at # `scale=1` (no overflow risk), then the actual run uses the largest power-of-2 scale that # keeps logged gradient magnitudes (and a small safety factor for hidden in-kernel @@ -80,11 +82,12 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): " per-sample token count, despite the name) and `data.maximum_document_length`.", hint=FieldHint.feature, ) - data_path: pathlib.Path | None = Field( + input_text_file: pathlib.Path | None = Field( default=None, - desc="If set, prepare a tokenized memmap dataset with advantages and `old_log_probabilities`" - " at this path (using the test helper `_get_test_dataset`) and use it as the training" - " input — required for policy-gradient losses like GSPO/GRPO. If unset, uses random tokens.", + desc="If set, tokenize this text file (via the pretrained tokenizer) to build the fixed model" + " input, tiled/truncated to `sequence_length`. If unset, the input is uniform-random token ids." + " The exact input tensor is saved to `output_dir/input_ids.pt` so the DeepSpeed-side tool" + " (`tools/evaluate_precision_deepspeed.py`) can consume the identical model input.", hint=FieldHint.feature, ) @@ -98,12 +101,12 @@ def _validate(self) -> None: def run(self) -> None: self.output_dir.mkdir(parents=True, exist_ok=True) - self._prepare_data() + input_ids = self._prepare_input_ids() runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) scales: dict[str, float] = {} for name, variant_overrides in runs.items(): - scales[name] = self._calibrate_and_run(name, variant_overrides) + scales[name] = self._calibrate_and_run(name, variant_overrides, input_ids) ref_artifacts = self._artifact_path(_REFERENCE_NAME) results = { @@ -120,7 +123,9 @@ def run(self) -> None: _print_table(name, rows) _print_summary(results) - def _calibrate_and_run(self, name: str, variant_overrides: dict[str, typing.Any]) -> float: + def _calibrate_and_run( + self, name: str, variant_overrides: dict[str, typing.Any], input_ids: "torch.Tensor" + ) -> float: """Pick a power-of-2 gradient scale for this variant via a calibration pass, then run with it. Calibration runs with `constant=1.0` so no overflow is possible; scanning logged gradients @@ -131,7 +136,7 @@ def _calibrate_and_run(self, name: str, variant_overrides: dict[str, typing.Any] import torch cal_dir = self.output_dir / f"{_CALIBRATION_SUBDIR_PREFIX}{name}" - self._run_one(name, variant_overrides, constant_scale=1.0, experiment_dir=cal_dir) + self._run_one(name, variant_overrides, input_ids, constant_scale=1.0, experiment_dir=cal_dir) max_unscaled = _scan_max_grad(cal_dir / "runs" / "0" / "artifacts") shutil.rmtree(cal_dir) if max_unscaled <= 0.0: @@ -142,25 +147,32 @@ def _calibrate_and_run(self, name: str, variant_overrides: dict[str, typing.Any] optimal_unrounded = fp16_max / max_unscaled / _HIDDEN_INTERMEDIATE_HEADROOM scale = float(2 ** max(0, math.floor(math.log2(optimal_unrounded)))) logger.info(f"[{name}] calibration: max_unscaled={max_unscaled:.4e} -> gradient_scaler.constant={scale:g}") - self._run_one(name, variant_overrides, constant_scale=scale) + self._run_one(name, variant_overrides, input_ids, constant_scale=scale) return scale - def _prepare_data(self) -> None: - if self.data_path is None: - return - if (self.data_path / "fast_llm_config.yaml").is_file(): - return - # Couples `tools/` to `tests/utils/` for now — extract later if it sticks. - from tests.utils.dataset import _get_test_dataset - - self.data_path.mkdir(parents=True, exist_ok=True) - logger.info(f"Preparing memmap dataset at {self.data_path}") - _get_test_dataset( - self.data_path, - seed=42, - has_grpo_data=True, - max_vocab_size=self.model.base_model.embeddings.vocab_size, - ) + def _prepare_input_ids(self) -> "torch.Tensor": + """Build the fixed model input once and save it so the DeepSpeed-side tool feeds the exact + same tokens. Going through Fast-LLM's data pipeline would re-randomize the model input + (shuffle/packing), so the input is constructed directly here and fed verbatim to the runner.""" + import torch + + vocab_size = self.model.base_model.embeddings.vocab_size + if self.input_text_file is not None: + import transformers + + tokenizer = transformers.AutoTokenizer.from_pretrained(str(self.pretrained.path)) + ids = tokenizer(self.input_text_file.read_text(), return_tensors="pt").input_ids[0] + if ids.numel() < self.sequence_length: + ids = ids.repeat((self.sequence_length + ids.numel() - 1) // ids.numel()) + ids = ids[: self.sequence_length].to(torch.int64) + else: + generator = torch.Generator().manual_seed(_INPUT_SEED) + ids = torch.randint(0, vocab_size, (self.sequence_length,), generator=generator, dtype=torch.int64) + input_ids = ids.unsqueeze(0) + path = self.output_dir / "input_ids.pt" + torch.save(input_ids, path) + logger.info(f"Shared model input: {tuple(input_ids.shape)} saved to {path}") + return input_ids def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" @@ -169,6 +181,7 @@ def _run_one( self, name: str, variant_overrides: dict[str, typing.Any], + input_ids: "torch.Tensor", *, constant_scale: float | None = None, experiment_dir: pathlib.Path | None = None, @@ -196,16 +209,11 @@ def _run_one( "logs": {"interval": 1}, }, "optimizer": optimizer_config, + # The lean runner feeds a fixed input directly and ignores this dataset; it's only here so + # the TrainerConfig validates. Despite the name, `data.micro_batch_size` is the per-sample + # sequence length, not the batch dimension. "data": { - "datasets": { - "training": ( - {"type": "file", "path": str(self.data_path / "fast_llm_config.yaml")} - if self.data_path is not None - else {"type": "random"} - ) - }, - # Despite the name, Fast-LLM's `data.micro_batch_size` is the per-sample sequence - # length, not the batch dimension. Default 2048 → 2048-token sample. + "datasets": {"training": {"type": "random"}}, "micro_batch_size": self.sequence_length, "maximum_document_length": self.sequence_length, }, @@ -260,7 +268,7 @@ def _run_one( trainer_class = TrainerConfig.get_subclass(_MODEL_TYPE) trainer_config = trainer_class.from_dict(base_dict, fp32_dtypes, variant_updates, tool_overrides) trainer_config.configure_logging() - trainer_config._get_runnable()() + _run_fixed_input(trainer_config, input_ids, self.sequence_length) def _compare( self, @@ -304,6 +312,65 @@ def _compare( return rows +def _run_fixed_input(config, input_ids, sequence_length: int) -> None: + """Lean forward+backward on a fixed, already-preprocessed input — like `InferenceRunner` but with a + training-phase schedule + an (lr-0) optimizer so `run_step` runs the backward and the existing + chosen-logprob loss / `debug_all_param_gradients` logging captures everything. Replaces the trainer + + data pipeline so the model sees exactly `input_ids` (the pipeline would re-randomize it) and so the + tool stops paying for training/data-loading infrastructure it doesn't need.""" + import gc + + import torch + + from fast_llm.data.document.language_model import LanguageModelBatch + from fast_llm.engine.distributed.config import PhaseType + from fast_llm.engine.distributed.distributed import Distributed + from fast_llm.engine.multi_stage.config import StageMode + from fast_llm.engine.optimizer.config import ParamGroup + from fast_llm.engine.schedule.runner import ScheduleRunner + from fast_llm.engine.schedule.schedule import Schedule + + distributed = Distributed(config.model.distributed) + run = config.get_run(distributed) + with run: + multi_stage = config.model.get_model_class()( + config.model, optimizer_state_names=config.optimizer.state_names() + ) + with torch.no_grad(): + multi_stage.setup(distributed, mode=StageMode.training) + multi_stage.load_checkpoint(config.pretrained) + param_groups, grads_for_norm = multi_stage.get_param_groups(ParamGroup) + optimizer = config.optimizer.optimizer_cls( + config.optimizer, param_groups=param_groups, grads_for_norm=grads_for_norm, distributed=distributed + ) + optimizer.reset_state() + runner = ScheduleRunner( + config=config.schedule, multi_stage=multi_stage, distributed_config=config.model.distributed + ) + with torch.no_grad(): + runner.setup(distributed, optimizer) + preprocessing_config = multi_stage.get_preprocessing_config( + PhaseType.training, config.schedule.micro_batch_splits + ) + # `get_model_inputs` splits off `num_labels` tokens for the shifted next-token labels, so the + # actual model input is `len(tokens) - num_labels`. The schedule meta must match that length. + schedule = Schedule( + config=config.schedule, + multi_stage=multi_stage, + batch_meta=preprocessing_config.get_input_meta(sequence_length - preprocessing_config.num_labels), + distributed_config=config.model.distributed, + phase=PhaseType.training, + ) + tokens = input_ids.flatten().to(device=distributed.device, dtype=torch.int64) + batch = LanguageModelBatch(tokens=tokens, lengths=[tokens.numel()]) + model_inputs = batch.get_model_inputs(preprocessing_config) + runner.run_step(iter((tuple(model_inputs),)), schedule, iteration=1) + # Break the trainer/model/runner reference cycles so each variant's GPU memory is reclaimed. + del multi_stage, optimizer, runner, schedule, distributed, run + gc.collect() + torch.cuda.empty_cache() + + def _is_gradient_like(tensor_name: str) -> bool: # Anything affected by the loss-scaling multiplier: parameter gradients from `Fsdp.log_shard`, # backward activations from layer hooks, and explicit `.grad` debug entries (e.g. logits.grad). diff --git a/tools/evaluate_precision_deepspeed.py b/tools/evaluate_precision_deepspeed.py new file mode 100644 index 000000000..00fa3e942 --- /dev/null +++ b/tools/evaluate_precision_deepspeed.py @@ -0,0 +1,281 @@ +"""Within-engine numerical-precision sweep for the HF-transformers + DeepSpeed stack. + +This is the DeepSpeed-side counterpart to `tools/evaluate_precision.py` (which measures the +same thing inside Fast-LLM). It loads a HF checkpoint, runs one forward + backward per precision +variant through a DeepSpeed engine, and reports two quantities against the fp32 reference, using +the same metrics (`CompareConfig._compute_diff`: RMS / bias / correlation / slope / residual): + + * chosen-token log-probability per position (the RL importance-ratio input); + * parameter gradients, aggregated by category (embedding/head, linear, norm, bias). + +The point is to check whether Fast-LLM's bf16 loses precision the *same way* DeepSpeed's bf16 +does — each measured against its own fp32 reference. + +The log-π computation and the fp32 LM-head mechanism mirror PipelineRL's DeepSpeed trainer +(`pipelinerl/finetune/rl/__init__.py` and `pipelinerl/finetune/checkpoints.py`) so the numbers +reflect the proven baseline rather than a bespoke path. `param.grad` is populated and already +unscaled after `engine.backward` (verified for both bf16 and fp16), so gradients are read directly. + +Run where transformers + deepspeed are installed (e.g. the PipelineRL stack image): + + python -m tools.evaluate_precision_deepspeed --model Qwen/Qwen2.5-0.5B --sequence-length 2048 +""" + +import argparse +import functools +import logging +import os +import statistics +import typing + +import torch + +logger = logging.getLogger(__name__) + +_REFERENCE_NAME = "fp32" +# (name, compute dtype, fp32 lm head). Reference is fp32 + fp32 head. `*_head_` variants +# turn the fp32 head OFF (head runs in compute dtype) to reproduce the within-engine +# "fp32 lm head has ~no effect" finding on the DeepSpeed side. +_VARIANTS: list[tuple[str, torch.dtype, bool]] = [ + (_REFERENCE_NAME, torch.float32, True), + ("bf16", torch.bfloat16, True), + ("bf16_head_bf16", torch.bfloat16, False), + ("fp16", torch.float16, True), + ("fp16_head_fp16", torch.float16, False), +] + +_FIXED_TEXT = ( + "The numerical precision of large language model training depends on the dtype used for " + "matrix multiplications, the accumulation precision of the hardware, and whether the output " + "projection is kept in full precision. In reinforcement learning from human feedback, the " + "importance ratio between the new and old policy is the exponential of the difference of " + "log-probabilities, so even small per-token errors in the log-probability can compound. " + "We compute the chosen-token log-probability as the log-softmax of the logits evaluated at " + "the next token, and we compare bfloat16 and float16 against a float32 reference. " +) + + +def apply_fp32_lm_head(model: torch.nn.Module, layer_prefix: str = "lm_head") -> torch.nn.Module: + """Cast the LM head to fp32 at compute time. Mirrors PipelineRL `apply_fp32_lm_head`. + + For tied embeddings (e.g. Qwen2.5-0.5B) the weight storage stays in the model dtype and is + upcast only for the head matmul; for untied heads the storage itself is moved to fp32. + """ + lm_head = model.get_output_embeddings() + if lm_head is None or not isinstance(lm_head, torch.nn.Linear): + raise RuntimeError(f"Could not find an nn.Linear LM head via get_output_embeddings(): {lm_head!r}") + tied = False + inp_emb = model.get_input_embeddings() + if inp_emb is not None and hasattr(inp_emb, "weight"): + tied = lm_head.weight is inp_emb.weight + if not tied and lm_head.weight.dtype != torch.float32: + lm_head.to(dtype=torch.float32) + original_forward = lm_head.forward + + @functools.wraps(original_forward) + def fp32_forward(x: torch.Tensor) -> torch.Tensor: + x32 = x if x.dtype == torch.float32 else x.float() + w = lm_head.weight + w32 = w if w.dtype == torch.float32 else w.float() + b = lm_head.bias + b32 = b.float() if (b is not None and b.dtype != torch.float32) else b + return torch.nn.functional.linear(x32, w32, b32) + + lm_head.forward = fp32_forward + logger.info(f"Applied fp32 lm head (tied={tied})") + return model + + +def chosen_logprob(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: + """log π(next token) per position. Mirrors PipelineRL `rl/__init__.py:203-208`.""" + logits = logits[:, :-1, :].float() / temperature + next_ids = input_ids[:, 1:].unsqueeze(-1) + selected = torch.gather(logits, 2, next_ids).squeeze(-1) + log_norm = torch.logsumexp(logits, dim=-1) + return (selected - log_norm).reshape(-1) + + +def build_input_ids(tokenizer, vocab_size: int, sequence_length: int, device: torch.device, mode: str) -> torch.Tensor: + if mode == "random": + # Match Fast-LLM's random dataset (uniform token ids over the model vocab) so both engines + # see the same input distribution. The relative metrics depend strongly on it: on random + # tokens the model is maximally surprised (|log π| large), on realistic text |log π| ≈ 0, + # which shifts the relative RMS by several-fold even at identical absolute precision. + generator = torch.Generator().manual_seed(0) + ids = torch.randint(0, vocab_size, (sequence_length,), generator=generator) + else: + ids = tokenizer(_FIXED_TEXT, return_tensors="pt").input_ids[0] + repeats = (sequence_length + ids.numel() - 1) // ids.numel() + ids = ids.repeat(repeats)[:sequence_length] + return ids.unsqueeze(0).to(device) + + +def _ds_config(dtype: torch.dtype) -> dict[str, typing.Any]: + config: dict[str, typing.Any] = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": {"type": "Adam", "params": {"lr": 1e-6}}, + } + if dtype == torch.bfloat16: + config["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config["fp16"] = {"enabled": True, "initial_scale_power": 16} + return config + + +def grad_category(name: str) -> str: + if name.endswith(".bias"): + return "bias" + if "layernorm" in name or name.endswith("norm.weight"): + return "norm" + if "embed_tokens" in name or "lm_head" in name: + return "embed_head" + return "linear" + + +def capture_variant( + model_id: str, dtype: torch.dtype, fp32_head: bool, input_ids: torch.Tensor, attn_implementation: str +) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Forward + backward one variant through a DeepSpeed engine. Returns (chosen_logprob, + {param_name: gradient}), both on CPU in fp32.""" + import deepspeed + import transformers + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_id, dtype=dtype, attn_implementation=attn_implementation + ) + if fp32_head: + apply_fp32_lm_head(model) + engine, *_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=_ds_config(dtype)) + outputs = engine(input_ids) + logprob = chosen_logprob(outputs.logits, input_ids) + # fp16's narrow exponent range underflows small gradients; scale the loss up before backward and + # divide it back out (loss scaling, as in fp16 training). bf16/fp32 have fp32 range, no scaling. + # engine.backward leaves param.grad unscaled, so dividing by our own loss_scale recovers the true + # gradient computed with extra headroom against underflow. + loss_scale = 256.0 if dtype == torch.float16 else 1.0 + engine.backward(-logprob.mean() * loss_scale) + grads = { + name: (p.grad.detach().float() / loss_scale).cpu() + for name, p in model.named_parameters() + if p.grad is not None + } + logprob = logprob.detach().float().cpu() + del engine, model, outputs + torch.cuda.empty_cache() + return logprob, grads + + +def _entry(tensor: torch.Tensor) -> dict[str, typing.Any]: + return {"shape": list(tensor.shape), "step": 1, "samples": tensor} + + +def _print_logprob_summary(metrics_by_variant: dict[str, dict[str, typing.Any]]) -> None: + cols = [ + ("RMS rel", lambda m: f"{m['rms_rel'] * 100:.4f}%"), + ("Bias rel", lambda m: f"{m['bias_rel'] * 100:+.4f}%"), + ("Resid rel", lambda m: f"{m['residual_rms_rel'] * 100:.4f}%"), + ("Corr", lambda m: f"{m['correlation']:.5f}"), + ("Slope", lambda m: f"{m['slope']:+.5f}"), + ("Max abs", lambda m: f"{m['max_abs']:.4g}"), + ("Scale", lambda m: f"{m['ref_scale']:.4g}"), + ] + _print_table("chosen_logprob (per-token) vs fp32 reference", metrics_by_variant, cols) + + +def _print_grad_summary(grad_metrics_by_variant: dict[str, dict[str, list[float]]]) -> None: + # Per-category aggregation of gradient RMS-rel, mirroring tools/evaluate_precision.py's grad table. + def med(values: list[float]) -> str: + return f"{statistics.median(values) * 100:.4f}%" if values else "n/a" + + def mx(values: list[float]) -> str: + return f"{max(values) * 100:.4f}%" if values else "n/a" + + cols = [ + ("embed_head", lambda c: med(c.get("embed_head", []))), + ("linear med", lambda c: med(c.get("linear", []))), + ("linear max", lambda c: mx(c.get("linear", []))), + ("norm med", lambda c: med(c.get("norm", []))), + ("norm max", lambda c: mx(c.get("norm", []))), + ("bias med", lambda c: med(c.get("bias", []))), + ("bias max", lambda c: mx(c.get("bias", []))), + ] + _print_table("gradient RMS-rel by category vs fp32 reference", grad_metrics_by_variant, cols) + + +def _print_table(title: str, by_variant: dict, cols: list[tuple[str, typing.Callable]]) -> None: + name_width = max((len(n) for n in by_variant), default=7) + 1 + widths = [max(len(label), max((len(fn(v)) for v in by_variant.values()), default=0)) for label, fn in cols] + print(f"\n=== DeepSpeed/HF: {title} ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, value in by_variant.items(): + cells = [fn(value) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B") + parser.add_argument("--sequence-length", type=int, default=2048) + parser.add_argument("--attn-implementation", default="sdpa") + parser.add_argument("--input-mode", choices=["random", "text"], default="random") + parser.add_argument( + "--input-file", + default=None, + help="Path to an input_ids.pt saved by tools/evaluate_precision.py. When set, feeds that exact" + " model input (so Fast-LLM and DeepSpeed see byte-identical tokens); --input-mode is ignored.", + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + for key, value in ( + ("RANK", "0"), + ("LOCAL_RANK", "0"), + ("WORLD_SIZE", "1"), + ("MASTER_ADDR", "127.0.0.1"), + ("MASTER_PORT", "29555"), + ): + os.environ.setdefault(key, value) + + import transformers + + from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig + + device = torch.device("cuda:0") + if args.input_file is not None: + input_ids = torch.load(args.input_file).to(device=device, dtype=torch.int64) + logger.info(f"Loaded shared model input {tuple(input_ids.shape)} from {args.input_file}") + else: + tokenizer = transformers.AutoTokenizer.from_pretrained(args.model) + vocab_size = transformers.AutoConfig.from_pretrained(args.model).vocab_size + input_ids = build_input_ids(tokenizer, vocab_size, args.sequence_length, device, args.input_mode) + logger.info(f"input_ids shape {tuple(input_ids.shape)}") + + compare = CompareConfig() + ref_logprob: torch.Tensor | None = None + ref_grads: dict[str, torch.Tensor] = {} + logprob_metrics: dict[str, dict[str, typing.Any]] = {} + grad_metrics: dict[str, dict[str, list[float]]] = {} + for name, dtype, fp32_head in _VARIANTS: + logger.info(f"=== variant {name} (dtype={dtype}, fp32_head={fp32_head}) ===") + logprob, grads = capture_variant(args.model, dtype, fp32_head, input_ids, args.attn_implementation) + if name == _REFERENCE_NAME: + ref_logprob, ref_grads = logprob, grads + logprob_metrics[name] = compare._compute_diff(_entry(ref_logprob), _entry(logprob), "step", "chosen_logprob") + by_category: dict[str, list[float]] = {} + for param_name, grad in grads.items(): + if param_name not in ref_grads: + continue + metrics = compare._compute_diff(_entry(ref_grads[param_name]), _entry(grad), "step", param_name) + by_category.setdefault(grad_category(param_name), []).append(metrics["rms_rel"]) + grad_metrics[name] = by_category + + _print_logprob_summary(logprob_metrics) + _print_grad_summary(grad_metrics) + + +if __name__ == "__main__": + main() From cecf7aef0610aa2f81a9827d16ff45e3c726c62e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 1 Jun 2026 12:54:00 -0400 Subject: [PATCH 33/41] Support random-init (model_weights=False) in both precision tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lean runner now honors pretrained.model_weights (initialize_weights when not loading), matching the trainer's branch; the DeepSpeed harness gains --random-init (build from config). Note: random init is a poor cross-engine test — the two engines use different init schemes (different models), and HF's from_config init yields near-uniform untrained logits where bf16 noise dominates (correlation ~0). The pretrained comparison is the meaningful one. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision.py | 5 ++++- tools/evaluate_precision_deepspeed.py | 28 ++++++++++++++++++++++----- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 48e67500c..9d63a54c8 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -338,7 +338,10 @@ def _run_fixed_input(config, input_ids, sequence_length: int) -> None: ) with torch.no_grad(): multi_stage.setup(distributed, mode=StageMode.training) - multi_stage.load_checkpoint(config.pretrained) + if config.pretrained.path is not None and config.pretrained.model_weights: + multi_stage.load_checkpoint(config.pretrained) + else: + multi_stage.initialize_weights() param_groups, grads_for_norm = multi_stage.get_param_groups(ParamGroup) optimizer = config.optimizer.optimizer_cls( config.optimizer, param_groups=param_groups, grads_for_norm=grads_for_norm, distributed=distributed diff --git a/tools/evaluate_precision_deepspeed.py b/tools/evaluate_precision_deepspeed.py index 00fa3e942..d9cf102dc 100644 --- a/tools/evaluate_precision_deepspeed.py +++ b/tools/evaluate_precision_deepspeed.py @@ -133,16 +133,26 @@ def grad_category(name: str) -> str: def capture_variant( - model_id: str, dtype: torch.dtype, fp32_head: bool, input_ids: torch.Tensor, attn_implementation: str + model_id: str, + dtype: torch.dtype, + fp32_head: bool, + input_ids: torch.Tensor, + attn_implementation: str, + random_init: bool = False, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Forward + backward one variant through a DeepSpeed engine. Returns (chosen_logprob, {param_name: gradient}), both on CPU in fp32.""" import deepspeed import transformers - model = transformers.AutoModelForCausalLM.from_pretrained( - model_id, dtype=dtype, attn_implementation=attn_implementation - ) + if random_init: + model = transformers.AutoModelForCausalLM.from_config( + transformers.AutoConfig.from_pretrained(model_id), dtype=dtype, attn_implementation=attn_implementation + ) + else: + model = transformers.AutoModelForCausalLM.from_pretrained( + model_id, dtype=dtype, attn_implementation=attn_implementation + ) if fp32_head: apply_fp32_lm_head(model) engine, *_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=_ds_config(dtype)) @@ -228,6 +238,12 @@ def main() -> None: help="Path to an input_ids.pt saved by tools/evaluate_precision.py. When set, feeds that exact" " model input (so Fast-LLM and DeepSpeed see byte-identical tokens); --input-mode is ignored.", ) + parser.add_argument( + "--random-init", + action="store_true", + help="Build the model from config with random weights instead of loading the pretrained" + " checkpoint (contrast with the pretrained run; weights won't match Fast-LLM's random init).", + ) args = parser.parse_args() logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") @@ -261,7 +277,9 @@ def main() -> None: grad_metrics: dict[str, dict[str, list[float]]] = {} for name, dtype, fp32_head in _VARIANTS: logger.info(f"=== variant {name} (dtype={dtype}, fp32_head={fp32_head}) ===") - logprob, grads = capture_variant(args.model, dtype, fp32_head, input_ids, args.attn_implementation) + logprob, grads = capture_variant( + args.model, dtype, fp32_head, input_ids, args.attn_implementation, args.random_init + ) if name == _REFERENCE_NAME: ref_logprob, ref_grads = logprob, grads logprob_metrics[name] = compare._compute_diff(_entry(ref_logprob), _entry(logprob), "step", "chosen_logprob") From a6d93143d231535b010569d5719bccde4727c111 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 1 Jun 2026 14:36:02 -0400 Subject: [PATCH 34/41] vLLM within-engine precision tool Subprocess-per-variant log-prob precision sweep mirroring the trainer-side tools: feeds a fixed prompt, reads vLLM prompt_logprobs (chosen-token log-pi aligned with the trainers), and compares each precision variant against the fp32 reference. Forces a single attention backend across variants to isolate precision from the kernel. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision_vllm.py | 247 +++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 tools/evaluate_precision_vllm.py diff --git a/tools/evaluate_precision_vllm.py b/tools/evaluate_precision_vllm.py new file mode 100644 index 000000000..ca9d917ce --- /dev/null +++ b/tools/evaluate_precision_vllm.py @@ -0,0 +1,247 @@ +"""Within-engine numerical-precision sweep for the vLLM inference stack. + +This is the vLLM-side counterpart to `tools/evaluate_precision.py` (Fast-LLM) and +`tools/evaluate_precision_deepspeed.py` (HF + DeepSpeed). It loads a checkpoint once per +precision variant, feeds a fixed prompt, and reports the per-token chosen-token +log-probability against the vLLM fp32 reference, using the same metrics +(`CompareConfig._compute_diff`: RMS / bias / correlation / slope / residual). + +vLLM is inference-only here, so there is no backward pass: only the forward log-π is measured. +For a prompt of tokens (t_0 .. t_{n-1}), vLLM's `prompt_logprobs[i]` is log P(t_i | t_0..t_{i-1}); +slicing `[1:]` gives the n-1 next-token log-probs that align 1:1 with the trainers' +`chosen_logprob` (log-softmax of the logits at the next token). + +The variants mirror the PipelineRL inference settings: + + * `bf16_fp32_head` is the production setting — vLLM is always launched with + `--quantization bf16_last_layer_fp32` (bf16 body, fp32 LM head with fp32 logits), which is + the vLLM analog of the trainers' fp32 LM head; + * `bf16` turns that off (head runs in bf16) to isolate the head's contribution. + +Each variant runs in its own subprocess because vLLM's global CUDA/engine state does not tear +down cleanly in-process; a fresh process per variant guarantees isolation. Each worker writes +`output_dir/logprobs_.pt`; the parent loads them and prints the comparison table. + +Run where vLLM (and, for the fp32-head variants, pipelinerl) is installed, e.g. the PipelineRL +stack image: + + python -m tools.evaluate_precision_vllm --model Qwen/Qwen2.5-0.5B --input-file /input_ids.pt +""" + +import argparse +import logging +import pathlib +import subprocess +import sys +import typing + +import torch + +logger = logging.getLogger(__name__) + +_REFERENCE_NAME = "fp32" +# (name, vLLM dtype, quantization). Reference is full fp32. `*_fp32_head` variants add the +# `bf16_last_layer_fp32` quantization (bf16/fp16 body, fp32 LM head + fp32 logits) — the +# production PipelineRL inference setting; the plain `bf16`/`fp16` variants run the head in the +# body dtype, isolating the fp32-head contribution. +# The bf16_last_layer_fp32 quantization only supports bf16/fp32 bodies (it rejects fp16), so there is +# no fp16_fp32_head variant. +_QUANTIZATION = "bf16_last_layer_fp32" +_VARIANTS: list[tuple[str, str, str | None]] = [ + (_REFERENCE_NAME, "float32", None), + ("bf16", "bfloat16", None), + ("bf16_fp32_head", "bfloat16", _QUANTIZATION), + ("fp16", "float16", None), +] + + +def build_input_ids(model: str, sequence_length: int, text_file: str | None) -> torch.Tensor: + """Build the fixed prompt, mirroring Fast-LLM's `tools/evaluate_precision.py:_prepare_input_ids` + (same tokenizer + truncation for text, same seed-0 uniform-random tokens otherwise) so a standalone + run is byte-identical to the trainers' and cross-engine-ready.""" + import transformers + + if text_file is not None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model) + ids = tokenizer(pathlib.Path(text_file).read_text(), return_tensors="pt").input_ids[0] + if ids.numel() < sequence_length: + ids = ids.repeat((sequence_length + ids.numel() - 1) // ids.numel()) + return ids[:sequence_length].to(torch.int64) + vocab_size = transformers.AutoConfig.from_pretrained(model).vocab_size + generator = torch.Generator().manual_seed(0) + return torch.randint(0, vocab_size, (sequence_length,), generator=generator, dtype=torch.int64) + + +def run_worker( + model: str, + variant: str, + dtype: str, + quantization: str | None, + attention_backend: str | None, + input_file: str, + output_dir: str, +) -> None: + """Load the model at one precision variant, feed the fixed prompt, save per-token chosen log-π.""" + import os + + # Force a single attention backend across all variants so the fp32-vs-bf16 diff reflects precision, + # not a backend switch (vLLM otherwise picks flash-attn for bf16/fp16 but a Triton/flex backend for + # fp32). TRITON_ATTN supports all three dtypes. Mirrors forcing sdpa on the DeepSpeed side. + if attention_backend is not None: + os.environ.setdefault("VLLM_ATTENTION_BACKEND", attention_backend) + if quantization is not None: + os.environ.setdefault("PIPELINERL_FP32_LAYER_PREFIX", "lm_head") + import pipelinerl.vllm_quantization # noqa: F401 registers the bf16_last_layer_fp32 config + + import vllm + + input_ids = torch.load(input_file).flatten().to(torch.int64) + token_ids = input_ids.tolist() + + llm = vllm.LLM( + model=model, + dtype=dtype, + quantization=quantization, + seed=0, + enforce_eager=True, + gpu_memory_utilization=0.9, + max_model_len=len(token_ids) + 16, + enable_prefix_caching=False, + logprobs_mode="processed_logprobs", + ) + sampling_params = vllm.SamplingParams(temperature=1.0, max_tokens=1, prompt_logprobs=0) + output = llm.generate(prompts=[{"prompt_token_ids": token_ids}], sampling_params=sampling_params)[0] + + # prompt_logprobs[i] is log P(t_i | t_0..t_{i-1}); [0] is None. Take the actual token at each + # position from i=1 on -> n-1 next-token log-probs aligned with the trainers' chosen_logprob. + prompt_logprobs = output.prompt_logprobs + logprobs = torch.tensor( + [prompt_logprobs[i][token_ids[i]].logprob for i in range(1, len(token_ids))], dtype=torch.float32 + ) + path = pathlib.Path(output_dir) / f"logprobs_{variant}.pt" + torch.save(logprobs, path) + logger.info(f"variant {variant}: {logprobs.numel()} tokens, scale {logprobs.square().mean().sqrt():.4g} -> {path}") + + +def _entry(tensor: torch.Tensor) -> dict[str, typing.Any]: + return {"shape": list(tensor.shape), "step": 1, "samples": tensor} + + +def _print_table(title: str, by_variant: dict, cols: list[tuple[str, typing.Callable]]) -> None: + name_width = max((len(n) for n in by_variant), default=7) + 1 + widths = [max(len(label), max((len(fn(v)) for v in by_variant.values()), default=0)) for label, fn in cols] + print(f"\n=== vLLM: {title} ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, value in by_variant.items(): + cells = [fn(value) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B") + parser.add_argument("--sequence-length", type=int, default=2048) + parser.add_argument( + "--input-file", + default=None, + help="Path to an input_ids.pt saved by tools/evaluate_precision.py. When set, feeds that exact" + " model input (byte-identical to the trainers); otherwise the input is built from" + " --input-text-file or seed-0 random tokens.", + ) + parser.add_argument( + "--input-text-file", + default=None, + help="Tokenize this text file (same tokenizer + truncation as Fast-LLM) for a realistic-text" + " prompt instead of random tokens. Ignored when --input-file is set.", + ) + parser.add_argument("--output-dir", default="/tmp/fast_llm_tests/evaluate_precision/vllm") + parser.add_argument( + "--attention-backend", + default="TRITON_ATTN", + help="vLLM attention backend forced for every variant (isolates precision from the kernel)." + " Pass 'auto' to let vLLM pick per dtype (the production path: flash-attn for bf16/fp16).", + ) + # Internal: a single-variant worker invocation (one vLLM engine per process). + parser.add_argument("--worker-variant", default=None, help=argparse.SUPPRESS) + parser.add_argument("--worker-dtype", default=None, help=argparse.SUPPRESS) + parser.add_argument("--worker-quantization", default=None, help=argparse.SUPPRESS) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + output_dir = pathlib.Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + attention_backend = None if args.attention_backend == "auto" else args.attention_backend + if args.worker_variant is not None: + quantization = args.worker_quantization or None + run_worker( + args.model, + args.worker_variant, + args.worker_dtype, + quantization, + attention_backend, + args.input_file, + args.output_dir, + ) + return + + if args.input_file is not None: + input_file = args.input_file + logger.info(f"Using shared model input {input_file}") + else: + input_ids = build_input_ids(args.model, args.sequence_length, args.input_text_file).unsqueeze(0) + input_file = str(output_dir / "input_ids.pt") + torch.save(input_ids, input_file) + kind = "text" if args.input_text_file is not None else "seed-0 random" + logger.info(f"Generated {kind} input {tuple(input_ids.shape)} -> {input_file}") + + for name, dtype, quantization in _VARIANTS: + logger.info(f"=== variant {name} (dtype={dtype}, quantization={quantization}) ===") + cmd = [ + sys.executable, + "-m", + "tools.evaluate_precision_vllm", + "--model", + args.model, + "--input-file", + input_file, + "--output-dir", + args.output_dir, + "--attention-backend", + args.attention_backend, + "--worker-variant", + name, + "--worker-dtype", + dtype, + ] + if quantization is not None: + cmd += ["--worker-quantization", quantization] + subprocess.run(cmd, check=True) + + from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig + + compare = CompareConfig() + ref = torch.load(output_dir / f"logprobs_{_REFERENCE_NAME}.pt") + metrics: dict[str, dict[str, typing.Any]] = {} + for name, _, _ in _VARIANTS: + logprob = torch.load(output_dir / f"logprobs_{name}.pt") + metrics[name] = compare._compute_diff(_entry(ref), _entry(logprob), "step", "chosen_logprob") + + cols = [ + ("RMS rel", lambda m: f"{m['rms_rel'] * 100:.4f}%"), + ("Bias rel", lambda m: f"{m['bias_rel'] * 100:+.4f}%"), + ("Resid rel", lambda m: f"{m['residual_rms_rel'] * 100:.4f}%"), + ("Corr", lambda m: f"{m['correlation']:.5f}"), + ("Slope", lambda m: f"{m['slope']:+.5f}"), + ("Max abs", lambda m: f"{m['max_abs']:.4g}"), + ("Scale", lambda m: f"{m['ref_scale']:.4g}"), + ] + _print_table("chosen_logprob (per-token) vs fp32 reference", metrics, cols) + + +if __name__ == "__main__": + main() From 26cd8ab8fd35c7ce384eb098bcb934c1c9b76782 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 1 Jun 2026 15:15:03 -0400 Subject: [PATCH 35/41] Auto-bind vLLM fp32 head on tied-embedding models The bf16_last_layer_fp32 quant matches its fp32 head by layer-name suffix. vLLM names the tied head embed_tokens (lm_head = embed_tokens), so the production lm_head prefix silently runs a bf16 head on tied models. Default --fp32-head-prefix auto now picks embed_tokens when embeddings are tied so the fp32 head genuinely binds (text bf16_fp32_head 1.05% -> 0.79%, matching the trainers); pass lm_head for the literal production setting. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision_vllm.py | 35 ++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tools/evaluate_precision_vllm.py b/tools/evaluate_precision_vllm.py index ca9d917ce..a1a2934c4 100644 --- a/tools/evaluate_precision_vllm.py +++ b/tools/evaluate_precision_vllm.py @@ -13,9 +13,11 @@ The variants mirror the PipelineRL inference settings: - * `bf16_fp32_head` is the production setting — vLLM is always launched with - `--quantization bf16_last_layer_fp32` (bf16 body, fp32 LM head with fp32 logits), which is - the vLLM analog of the trainers' fp32 LM head; + * `bf16_fp32_head` is the production setting — `--quantization bf16_last_layer_fp32` (bf16 body, fp32 + LM head with fp32 logits), the vLLM analog of the trainers' fp32 LM head. The quant binds its fp32 + head by layer-name suffix; for tied-embedding models vLLM names the head `embed_tokens` rather than + `lm_head`, so `--fp32-head-prefix auto` targets the right layer (otherwise the head silently no-ops + on tied models, running in bf16); * `bf16` turns that off (head runs in bf16) to isolate the head's contribution. Each variant runs in its own subprocess because vLLM's global CUDA/engine state does not tear @@ -72,12 +74,26 @@ def build_input_ids(model: str, sequence_length: int, text_file: str | None) -> return torch.randint(0, vocab_size, (sequence_length,), generator=generator, dtype=torch.int64) +def resolve_fp32_head_prefix(model: str, prefix: str) -> str: + """The bf16_last_layer_fp32 quant binds its fp32 head by matching the layer-name suffix. For tied + embeddings vLLM sets `lm_head = model.embed_tokens`, so the head is named `embed_tokens`, not + `lm_head` — the production `lm_head` prefix then silently no-ops. Pick the suffix that matches the + actual output head so the fp32 head genuinely binds on tied models too.""" + if prefix != "auto": + return prefix + import transformers + + tied = transformers.AutoConfig.from_pretrained(model).tie_word_embeddings + return "embed_tokens" if tied else "lm_head" + + def run_worker( model: str, variant: str, dtype: str, quantization: str | None, attention_backend: str | None, + fp32_head_prefix: str, input_file: str, output_dir: str, ) -> None: @@ -90,7 +106,7 @@ def run_worker( if attention_backend is not None: os.environ.setdefault("VLLM_ATTENTION_BACKEND", attention_backend) if quantization is not None: - os.environ.setdefault("PIPELINERL_FP32_LAYER_PREFIX", "lm_head") + os.environ["PIPELINERL_FP32_LAYER_PREFIX"] = fp32_head_prefix import pipelinerl.vllm_quantization # noqa: F401 registers the bf16_last_layer_fp32 config import vllm @@ -165,6 +181,13 @@ def main() -> None: help="vLLM attention backend forced for every variant (isolates precision from the kernel)." " Pass 'auto' to let vLLM pick per dtype (the production path: flash-attn for bf16/fp16).", ) + parser.add_argument( + "--fp32-head-prefix", + default="auto", + help="Layer-name suffix the bf16_last_layer_fp32 quant matches for its fp32 head. 'auto' picks" + " embed_tokens for tied-embedding models / lm_head otherwise so the fp32 head actually binds." + " Pass 'lm_head' for the literal production setting (a no-op on tied models).", + ) # Internal: a single-variant worker invocation (one vLLM engine per process). parser.add_argument("--worker-variant", default=None, help=argparse.SUPPRESS) parser.add_argument("--worker-dtype", default=None, help=argparse.SUPPRESS) @@ -176,6 +199,7 @@ def main() -> None: output_dir.mkdir(parents=True, exist_ok=True) attention_backend = None if args.attention_backend == "auto" else args.attention_backend + fp32_head_prefix = resolve_fp32_head_prefix(args.model, args.fp32_head_prefix) if args.worker_variant is not None: quantization = args.worker_quantization or None run_worker( @@ -184,6 +208,7 @@ def main() -> None: args.worker_dtype, quantization, attention_backend, + fp32_head_prefix, args.input_file, args.output_dir, ) @@ -213,6 +238,8 @@ def main() -> None: args.output_dir, "--attention-backend", args.attention_backend, + "--fp32-head-prefix", + fp32_head_prefix, "--worker-variant", name, "--worker-dtype", From fc6072cd4b06f4b77c673669a4bc7ea4c8832e6b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 2 Jun 2026 11:53:47 -0400 Subject: [PATCH 36/41] =?UTF-8?q?Cross-engine=20log-prob=20comparison=20to?= =?UTF-8?q?ol=20+=20per-token=20log=20=CF=80=20persistence?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tools/evaluate_precision_cross_engine.py: loads each engine's per-token log π vectors and reports the cross-engine log-ratio δ = log π_A − log π_B (mean/RMS/max/clip), plus the error-correlation decomposition δ = floor + (e_A − e_B) with ρ = corr(e_A, e_B) — the quantity that explains why fp32-head matters across engines (removes the uncorrelated head-rounding component) while being small within one. Persist the full per-token log π (token order, aligned 1:1 with vLLM's prompt_logprobs[1:]) from the two trainer tools so the comparison can consume plain tensors: evaluate_precision.py extracts the chosen_logprob vector from the run artifacts; evaluate_precision_deepspeed.py gains --output-dir. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision.py | 17 ++ tools/evaluate_precision_cross_engine.py | 209 +++++++++++++++++++++++ tools/evaluate_precision_deepspeed.py | 12 ++ 3 files changed, 238 insertions(+) create mode 100644 tools/evaluate_precision_cross_engine.py diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 9d63a54c8..3bf5e4c12 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -107,6 +107,7 @@ def run(self) -> None: scales: dict[str, float] = {} for name, variant_overrides in runs.items(): scales[name] = self._calibrate_and_run(name, variant_overrides, input_ids) + self._save_chosen_logprob(name) ref_artifacts = self._artifact_path(_REFERENCE_NAME) results = { @@ -177,6 +178,22 @@ def _prepare_input_ids(self) -> "torch.Tensor": def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" + def _save_chosen_logprob(self, name: str) -> None: + """Persist the full per-token log π vector (token order) as a plain tensor for the + cross-engine comparison. The chosen_logprob loss logs the whole tensor with step=1, so the + saved samples are the complete ordered vector — aligned 1:1 with vLLM's `prompt_logprobs[1:]`.""" + import torch + + compare_config = CompareConfig() + errors: list[str] = [] + logs = compare_config._extract_tensor_logs(self._artifact_path(name), errors) + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if tensor_name.split(":", 1)[-1].strip() == _CHOSEN_LOGPROB_NAME: + torch.save(entry["samples"].float().cpu(), self.output_dir / f"logprobs_{name}.pt") + return + logger.warning(f"[{name}] chosen_logprob not found in tensor logs; cross-engine vector not saved") + def _run_one( self, name: str, diff --git a/tools/evaluate_precision_cross_engine.py b/tools/evaluate_precision_cross_engine.py new file mode 100644 index 000000000..1618f0f71 --- /dev/null +++ b/tools/evaluate_precision_cross_engine.py @@ -0,0 +1,209 @@ +"""Cross-engine per-token log-probability comparison for Fast-LLM, DeepSpeed and vLLM. + +The within-engine tools (`tools/evaluate_precision{,_deepspeed,_vllm}.py`) each measure a variant +against *its own* fp32 reference — that isolates one engine's internal rounding. This tool measures +the gap *across* engines on the identical input: the per-token log-ratio + + δ = log π_A − log π_B + +over the chosen tokens. When A is a trainer and B is the vLLM sampler, δ is the log of the RL +importance ratio exp(log π_train − log π_old) that multiplies the advantage, so δ is exactly the +quantity that perturbs the gradient — and the gap the literature quotes in nats. + +Each engine's per-token log π is a plain fp32 vector saved by its within-engine tool +(`/logprobs_.pt`), all length L−1 and aligned 1:1 on the shared input. This tool maps +each engine's variant names to three canonical configs and reports two things per regime: + + 1. the δ distribution (mean = systematic bias, RMS, max, per-sequence sum, PPO clip fraction) for + the fp32 floor, the matched production config (bf16 body + fp32 head on both sides), and the + mismatched config (vLLM's as-shipped bf16 head vs the trainer's fp32 head); + 2. the error-correlation decomposition. With e = log π_bf16 − log π_fp32 per engine, + δ_AB = (fp32 floor) + (e_A − e_B), and RMS(e_A − e_B) is governed by ρ = corr(e_A, e_B): + ρ→1 means the engines round the same way and the errors cancel (gap collapses to the floor); + ρ→0 means independent rounding (errors add in quadrature). fp32-head matching works by removing + the large, cross-engine-uncorrelated head-rounding component, which is why it matters across + engines while being nearly invisible within one. + +Run (any subset of the three dirs; pairs are formed from whatever is available): + + python -m tools.evaluate_precision_cross_engine \\ + --fast-llm-dir /text --deepspeed-dir /text/ds --vllm-dir /text/vllm \\ + --label text +""" + +import argparse +import itertools +import pathlib + +import torch + +# Canonical cross-engine configs, mapped to each within-engine tool's variant names. +_ENGINE_VARIANTS: dict[str, dict[str, str]] = { + "fast_llm": {"fp32": "reference", "bf16_fp32head": "bf16_fp32_lm_head", "bf16_bf16head": "bf16"}, + "deepspeed": {"fp32": "fp32", "bf16_fp32head": "bf16", "bf16_bf16head": "bf16_head_bf16"}, + "vllm": {"fp32": "fp32", "bf16_fp32head": "bf16_fp32_head", "bf16_bf16head": "bf16"}, +} +_ENGINE_ORDER = ("fast_llm", "deepspeed", "vllm") +_ENGINE_LABELS = {"fast_llm": "Fast-LLM", "deepspeed": "DeepSpeed", "vllm": "vLLM"} + + +def _rms(x: torch.Tensor) -> float: + return x.pow(2).mean().sqrt().item() + + +def _corr(x: torch.Tensor, y: torch.Tensor) -> float: + x = x - x.mean() + y = y - y.mean() + denom = x.norm() * y.norm() + return (x @ y / denom).item() if denom > 0 else float("nan") + + +def _slope(reference: torch.Tensor, test: torch.Tensor) -> float: + # Regression of `test` on `reference` (test ≈ slope · reference); slope ≠ 1 is a multiplicative + # mismatch (e.g. a temperature/scale discrepancy), distinct from the additive offset mean(δ) catches. + reference = reference - reference.mean() + test = test - test.mean() + denom = reference.pow(2).sum() + return (reference @ test / denom).item() if denom > 0 else float("nan") + + +def _delta_stats(a: torch.Tensor, b: torch.Tensor, epsilon: float) -> dict[str, float]: + delta = a - b + abs_delta = delta.abs() + return { + "mean": delta.mean().item(), + "rms": _rms(delta), + "max": abs_delta.max().item(), + "sum": delta.sum().item(), + "clip": (abs_delta > epsilon).float().mean().item(), + } + + +def _load_engine(engine: str, directory: pathlib.Path) -> dict[str, torch.Tensor]: + vectors: dict[str, torch.Tensor] = {} + for config, variant in _ENGINE_VARIANTS[engine].items(): + path = directory / f"logprobs_{variant}.pt" + if path.exists(): + vectors[config] = torch.load(path, map_location="cpu").float().flatten() + else: + print(f" [{_ENGINE_LABELS[engine]}] {config}: missing {path} — skipping") + return vectors + + +def _error_vector(vectors: dict[str, torch.Tensor], engine: str, mode: str) -> torch.Tensor | None: + # bf16 rounding error against this engine's own fp32. In the mismatched config only vLLM keeps a + # bf16 head; the trainers always use the fp32 head. + config = "bf16_bf16head" if (mode == "mismatched" and engine == "vllm") else "bf16_fp32head" + if config not in vectors or "fp32" not in vectors: + return None + return vectors[config] - vectors["fp32"] + + +def _pair_config(engine: str, mode: str) -> str: + if mode == "fp32": + return "fp32" + if mode == "matched": + return "bf16_fp32head" + return "bf16_bf16head" if engine == "vllm" else "bf16_fp32head" + + +def _print_delta_table( + rows: list[tuple[str, str, str, dict[str, float]]], sequence_length: int, epsilon: float, label: str +) -> None: + print(f"\n=== Cross-engine log π gap{f' [{label}]' if label else ''} (δ = A − B, nats) ===") + print(f"(per-sequence sum over {sequence_length} tokens; clip = fraction with |δ| > {epsilon})") + header = f"{'group':<22} {'A − B':<22} {'mean δ':>10} {'RMS δ':>9} {'max|δ|':>9} {'Σδ (seq)':>10} {'clip%':>7}" + print(header) + print("-" * len(header)) + for group, engine_a, engine_b, stats in rows: + pair = f"{_ENGINE_LABELS[engine_a]} − {_ENGINE_LABELS[engine_b]}" + print( + f"{group:<22} {pair:<22} {stats['mean']:>+10.4f} {stats['rms']:>9.4f} {stats['max']:>9.4f}" + f" {stats['sum']:>+10.2f} {stats['clip'] * 100:>6.2f}%" + ) + + +def _print_decomposition_table(rows: list[tuple[str, str, str, dict[str, float]]], label: str) -> None: + print(f"\n=== Error-correlation decomposition{f' [{label}]' if label else ''} (e = bf16 − fp32, nats) ===") + print("(δ = floor + (e_A − e_B); ρ = corr(e_A, e_B): ρ→1 errors cancel, ρ→0 add in quadrature)") + header = ( + f"{'config':<12} {'A − B':<22} {'ρ(err)':>8} {'RMS e_A':>9} {'RMS e_B':>9}" + f" {'RMS(e_A−e_B)':>13} {'RMS floor':>10} {'slope':>8}" + ) + print(header) + print("-" * len(header)) + for config, engine_a, engine_b, stats in rows: + pair = f"{_ENGINE_LABELS[engine_a]} − {_ENGINE_LABELS[engine_b]}" + print( + f"{config:<12} {pair:<22} {stats['rho']:>8.4f} {stats['rms_a']:>9.4f} {stats['rms_b']:>9.4f}" + f" {stats['rms_diff']:>13.4f} {stats['rms_floor']:>10.4f} {stats['slope']:>+8.4f}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--fast-llm-dir", default=None, help="Dir with Fast-LLM logprobs_.pt files.") + parser.add_argument("--deepspeed-dir", default=None, help="Dir with DeepSpeed logprobs_.pt files.") + parser.add_argument("--vllm-dir", default=None, help="Dir with vLLM logprobs_.pt files.") + parser.add_argument( + "--epsilon", type=float, default=0.2, help="PPO clip band for the clip-fraction column (default 0.2)." + ) + parser.add_argument("--label", default="", help="Regime label for the table headers (e.g. 'text', 'random').") + args = parser.parse_args() + + dirs = {"fast_llm": args.fast_llm_dir, "deepspeed": args.deepspeed_dir, "vllm": args.vllm_dir} + print("Loading per-engine log π vectors:") + engines: dict[str, dict[str, torch.Tensor]] = {} + for engine in _ENGINE_ORDER: + if dirs[engine] is not None: + engines[engine] = _load_engine(engine, pathlib.Path(dirs[engine])) + + lengths = {vector.numel() for vectors in engines.values() for vector in vectors.values()} + if len(lengths) > 1: + raise ValueError(f"Per-token log π vectors have mismatched lengths {sorted(lengths)} — inputs not aligned.") + sequence_length = next(iter(lengths)) if lengths else 0 + available = [engine for engine in _ENGINE_ORDER if engine in engines] + pairs = list(itertools.combinations(available, 2)) + + delta_rows: list[tuple[str, str, str, dict[str, float]]] = [] + for mode, group_label in (("fp32", "fp32 floor"), ("matched", "matched (fp32 head)")): + for engine_a, engine_b in pairs: + config_a, config_b = _pair_config(engine_a, mode), _pair_config(engine_b, mode) + if config_a in engines[engine_a] and config_b in engines[engine_b]: + stats = _delta_stats(engines[engine_a][config_a], engines[engine_b][config_b], args.epsilon) + delta_rows.append((group_label, engine_a, engine_b, stats)) + for engine_a, engine_b in pairs: + if "vllm" not in (engine_a, engine_b): + continue + config_a, config_b = _pair_config(engine_a, "mismatched"), _pair_config(engine_b, "mismatched") + if config_a in engines[engine_a] and config_b in engines[engine_b]: + stats = _delta_stats(engines[engine_a][config_a], engines[engine_b][config_b], args.epsilon) + delta_rows.append(("mismatched (vLLM bf16 head)", engine_a, engine_b, stats)) + + decomposition_rows: list[tuple[str, str, str, dict[str, float]]] = [] + for mode in ("matched", "mismatched"): + for engine_a, engine_b in pairs: + if mode == "mismatched" and "vllm" not in (engine_a, engine_b): + continue + error_a = _error_vector(engines[engine_a], engine_a, mode) + error_b = _error_vector(engines[engine_b], engine_b, mode) + if error_a is None or error_b is None: + continue + floor = engines[engine_a]["fp32"] - engines[engine_b]["fp32"] + config_a, config_b = _pair_config(engine_a, mode), _pair_config(engine_b, mode) + stats = { + "rho": _corr(error_a, error_b), + "rms_a": _rms(error_a), + "rms_b": _rms(error_b), + "rms_diff": _rms(error_a - error_b), + "rms_floor": _rms(floor), + "slope": _slope(engines[engine_b][config_b], engines[engine_a][config_a]), + } + decomposition_rows.append((mode, engine_a, engine_b, stats)) + + _print_delta_table(delta_rows, sequence_length, args.epsilon, args.label) + _print_decomposition_table(decomposition_rows, args.label) + + +if __name__ == "__main__": + main() diff --git a/tools/evaluate_precision_deepspeed.py b/tools/evaluate_precision_deepspeed.py index d9cf102dc..f02c8e17a 100644 --- a/tools/evaluate_precision_deepspeed.py +++ b/tools/evaluate_precision_deepspeed.py @@ -25,6 +25,7 @@ import functools import logging import os +import pathlib import statistics import typing @@ -244,6 +245,13 @@ def main() -> None: help="Build the model from config with random weights instead of loading the pretrained" " checkpoint (contrast with the pretrained run; weights won't match Fast-LLM's random init).", ) + parser.add_argument( + "--output-dir", + default=None, + help="If set, save each variant's full per-token log π vector to" + " `/logprobs_.pt` (plain fp32 CPU tensor, aligned 1:1 with vLLM's" + " `prompt_logprobs[1:]`) for the cross-engine comparison.", + ) args = parser.parse_args() logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") @@ -280,6 +288,10 @@ def main() -> None: logprob, grads = capture_variant( args.model, dtype, fp32_head, input_ids, args.attn_implementation, args.random_init ) + if args.output_dir is not None: + output_dir = pathlib.Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + torch.save(logprob, output_dir / f"logprobs_{name}.pt") if name == _REFERENCE_NAME: ref_logprob, ref_grads = logprob, grads logprob_metrics[name] = compare._compute_diff(_entry(ref_logprob), _entry(logprob), "step", "chosen_logprob") From 767e7ebd01b9f3970e8b2c67eef6891da485a5e9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 2 Jun 2026 12:38:42 -0400 Subject: [PATCH 37/41] Cross-engine: all-pairwise mismatched group + fp16 precision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generalize the comparison over precisions {bf16, fp16}: each gets a matched (fp32 head both sides) and mismatched (A fp32 head, B body-dtype head) group, both now spanning every engine pair — the mismatched group previously only covered vLLM pairs, so Fast-LLM−DeepSpeed was missing. vLLM has no fp16+fp32 head (its quant rejects an fp16 body), so fp16 matched is trainer-only; fp16 mismatched still covers all pairs with vLLM as the body-head side. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision_cross_engine.py | 137 +++++++++++++---------- 1 file changed, 79 insertions(+), 58 deletions(-) diff --git a/tools/evaluate_precision_cross_engine.py b/tools/evaluate_precision_cross_engine.py index 1618f0f71..8d72ed61d 100644 --- a/tools/evaluate_precision_cross_engine.py +++ b/tools/evaluate_precision_cross_engine.py @@ -14,9 +14,10 @@ (`/logprobs_.pt`), all length L−1 and aligned 1:1 on the shared input. This tool maps each engine's variant names to three canonical configs and reports two things per regime: - 1. the δ distribution (mean = systematic bias, RMS, max, per-sequence sum, PPO clip fraction) for - the fp32 floor, the matched production config (bf16 body + fp32 head on both sides), and the - mismatched config (vLLM's as-shipped bf16 head vs the trainer's fp32 head); + 1. the δ distribution (mean = systematic bias, RMS, max, per-sequence sum, PPO clip fraction) over + every engine pair, for the fp32 floor and, per low precision (bf16, fp16), the matched config + (fp32 head on both sides) and the mismatched config (A keeps the fp32 head, B runs the head in + the body dtype — e.g. vLLM's as-shipped bf16 head vs a trainer's fp32 head); 2. the error-correlation decomposition. With e = log π_bf16 − log π_fp32 per engine, δ_AB = (fp32 floor) + (e_A − e_B), and RMS(e_A − e_B) is governed by ρ = corr(e_A, e_B): ρ→1 means the engines round the same way and the errors cancel (gap collapses to the floor); @@ -37,16 +38,45 @@ import torch -# Canonical cross-engine configs, mapped to each within-engine tool's variant names. +# Canonical cross-engine configs, mapped to each within-engine tool's variant names. For a precision +# `prec`, `_fp32head` keeps the LM head / logits in fp32 while `_head` runs the head +# in the body dtype. vLLM has no fp16+fp32 head (its quant rejects an fp16 body), so that entry is absent. _ENGINE_VARIANTS: dict[str, dict[str, str]] = { - "fast_llm": {"fp32": "reference", "bf16_fp32head": "bf16_fp32_lm_head", "bf16_bf16head": "bf16"}, - "deepspeed": {"fp32": "fp32", "bf16_fp32head": "bf16", "bf16_bf16head": "bf16_head_bf16"}, - "vllm": {"fp32": "fp32", "bf16_fp32head": "bf16_fp32_head", "bf16_bf16head": "bf16"}, + "fast_llm": { + "fp32": "reference", + "bf16_fp32head": "bf16_fp32_lm_head", + "bf16_bf16head": "bf16", + "fp16_fp32head": "fp16_fp32_lm_head", + "fp16_fp16head": "fp16", + }, + "deepspeed": { + "fp32": "fp32", + "bf16_fp32head": "bf16", + "bf16_bf16head": "bf16_head_bf16", + "fp16_fp32head": "fp16", + "fp16_fp16head": "fp16_head_fp16", + }, + "vllm": { + "fp32": "fp32", + "bf16_fp32head": "bf16_fp32_head", + "bf16_bf16head": "bf16", + "fp16_fp16head": "fp16", + }, } +_PRECISIONS = ("bf16", "fp16") _ENGINE_ORDER = ("fast_llm", "deepspeed", "vllm") _ENGINE_LABELS = {"fast_llm": "Fast-LLM", "deepspeed": "DeepSpeed", "vllm": "vLLM"} +def _fp32_head(precision: str) -> str: + return f"{precision}_fp32head" + + +def _body_head(precision: str) -> str: + # Head / logits in the body dtype (no fp32 upcast) — the low-precision-head config. + return f"{precision}_{precision}head" + + def _rms(x: torch.Tensor) -> float: return x.pow(2).mean().sqrt().item() @@ -90,28 +120,12 @@ def _load_engine(engine: str, directory: pathlib.Path) -> dict[str, torch.Tensor return vectors -def _error_vector(vectors: dict[str, torch.Tensor], engine: str, mode: str) -> torch.Tensor | None: - # bf16 rounding error against this engine's own fp32. In the mismatched config only vLLM keeps a - # bf16 head; the trainers always use the fp32 head. - config = "bf16_bf16head" if (mode == "mismatched" and engine == "vllm") else "bf16_fp32head" - if config not in vectors or "fp32" not in vectors: - return None - return vectors[config] - vectors["fp32"] - - -def _pair_config(engine: str, mode: str) -> str: - if mode == "fp32": - return "fp32" - if mode == "matched": - return "bf16_fp32head" - return "bf16_bf16head" if engine == "vllm" else "bf16_fp32head" - - def _print_delta_table( rows: list[tuple[str, str, str, dict[str, float]]], sequence_length: int, epsilon: float, label: str ) -> None: print(f"\n=== Cross-engine log π gap{f' [{label}]' if label else ''} (δ = A − B, nats) ===") print(f"(per-sequence sum over {sequence_length} tokens; clip = fraction with |δ| > {epsilon})") + print("(matched = both engines fp32 head; mismatched = A fp32 head, B body-dtype head)") header = f"{'group':<22} {'A − B':<22} {'mean δ':>10} {'RMS δ':>9} {'max|δ|':>9} {'Σδ (seq)':>10} {'clip%':>7}" print(header) print("-" * len(header)) @@ -124,10 +138,13 @@ def _print_delta_table( def _print_decomposition_table(rows: list[tuple[str, str, str, dict[str, float]]], label: str) -> None: - print(f"\n=== Error-correlation decomposition{f' [{label}]' if label else ''} (e = bf16 − fp32, nats) ===") + print( + f"\n=== Error-correlation decomposition{f' [{label}]' if label else ''}" + " (e = low-precision − fp32, nats) ===" + ) print("(δ = floor + (e_A − e_B); ρ = corr(e_A, e_B): ρ→1 errors cancel, ρ→0 add in quadrature)") header = ( - f"{'config':<12} {'A − B':<22} {'ρ(err)':>8} {'RMS e_A':>9} {'RMS e_B':>9}" + f"{'config':<16} {'A − B':<22} {'ρ(err)':>8} {'RMS e_A':>9} {'RMS e_B':>9}" f" {'RMS(e_A−e_B)':>13} {'RMS floor':>10} {'slope':>8}" ) print(header) @@ -135,7 +152,7 @@ def _print_decomposition_table(rows: list[tuple[str, str, str, dict[str, float]] for config, engine_a, engine_b, stats in rows: pair = f"{_ENGINE_LABELS[engine_a]} − {_ENGINE_LABELS[engine_b]}" print( - f"{config:<12} {pair:<22} {stats['rho']:>8.4f} {stats['rms_a']:>9.4f} {stats['rms_b']:>9.4f}" + f"{config:<16} {pair:<22} {stats['rho']:>8.4f} {stats['rms_a']:>9.4f} {stats['rms_b']:>9.4f}" f" {stats['rms_diff']:>13.4f} {stats['rms_floor']:>10.4f} {stats['slope']:>+8.4f}" ) @@ -165,41 +182,45 @@ def main() -> None: available = [engine for engine in _ENGINE_ORDER if engine in engines] pairs = list(itertools.combinations(available, 2)) + # δ table: fp32 floor, then per precision the matched (both fp32 head) and mismatched (A fp32 head, + # B body-dtype head) gaps over every available pair. delta_rows: list[tuple[str, str, str, dict[str, float]]] = [] - for mode, group_label in (("fp32", "fp32 floor"), ("matched", "matched (fp32 head)")): - for engine_a, engine_b in pairs: - config_a, config_b = _pair_config(engine_a, mode), _pair_config(engine_b, mode) - if config_a in engines[engine_a] and config_b in engines[engine_b]: - stats = _delta_stats(engines[engine_a][config_a], engines[engine_b][config_b], args.epsilon) - delta_rows.append((group_label, engine_a, engine_b, stats)) for engine_a, engine_b in pairs: - if "vllm" not in (engine_a, engine_b): - continue - config_a, config_b = _pair_config(engine_a, "mismatched"), _pair_config(engine_b, "mismatched") - if config_a in engines[engine_a] and config_b in engines[engine_b]: - stats = _delta_stats(engines[engine_a][config_a], engines[engine_b][config_b], args.epsilon) - delta_rows.append(("mismatched (vLLM bf16 head)", engine_a, engine_b, stats)) + if "fp32" in engines[engine_a] and "fp32" in engines[engine_b]: + stats = _delta_stats(engines[engine_a]["fp32"], engines[engine_b]["fp32"], args.epsilon) + delta_rows.append(("fp32 floor", engine_a, engine_b, stats)) + for precision in _PRECISIONS: + fp32_head, body_head = _fp32_head(precision), _body_head(precision) + for group_label, config_b in ((f"{precision} matched", fp32_head), (f"{precision} mismatched", body_head)): + for engine_a, engine_b in pairs: + if fp32_head in engines[engine_a] and config_b in engines[engine_b]: + stats = _delta_stats(engines[engine_a][fp32_head], engines[engine_b][config_b], args.epsilon) + delta_rows.append((group_label, engine_a, engine_b, stats)) decomposition_rows: list[tuple[str, str, str, dict[str, float]]] = [] - for mode in ("matched", "mismatched"): - for engine_a, engine_b in pairs: - if mode == "mismatched" and "vllm" not in (engine_a, engine_b): - continue - error_a = _error_vector(engines[engine_a], engine_a, mode) - error_b = _error_vector(engines[engine_b], engine_b, mode) - if error_a is None or error_b is None: - continue - floor = engines[engine_a]["fp32"] - engines[engine_b]["fp32"] - config_a, config_b = _pair_config(engine_a, mode), _pair_config(engine_b, mode) - stats = { - "rho": _corr(error_a, error_b), - "rms_a": _rms(error_a), - "rms_b": _rms(error_b), - "rms_diff": _rms(error_a - error_b), - "rms_floor": _rms(floor), - "slope": _slope(engines[engine_b][config_b], engines[engine_a][config_a]), - } - decomposition_rows.append((mode, engine_a, engine_b, stats)) + for precision in _PRECISIONS: + fp32_head, body_head = _fp32_head(precision), _body_head(precision) + for mode, config_b in (("matched", fp32_head), ("mismatched", body_head)): + for engine_a, engine_b in pairs: + if not ( + "fp32" in engines[engine_a] + and "fp32" in engines[engine_b] + and fp32_head in engines[engine_a] + and config_b in engines[engine_b] + ): + continue + error_a = engines[engine_a][fp32_head] - engines[engine_a]["fp32"] + error_b = engines[engine_b][config_b] - engines[engine_b]["fp32"] + floor = engines[engine_a]["fp32"] - engines[engine_b]["fp32"] + stats = { + "rho": _corr(error_a, error_b), + "rms_a": _rms(error_a), + "rms_b": _rms(error_b), + "rms_diff": _rms(error_a - error_b), + "rms_floor": _rms(floor), + "slope": _slope(engines[engine_b][config_b], engines[engine_a][fp32_head]), + } + decomposition_rows.append((f"{precision} {mode}", engine_a, engine_b, stats)) _print_delta_table(delta_rows, sequence_length, args.epsilon, args.label) _print_decomposition_table(decomposition_rows, args.label) From 2828c3524fa94c45f61c76d23a779169066fc531 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 2 Jun 2026 12:49:31 -0400 Subject: [PATCH 38/41] Cross-engine: full 2x2 head matrix per pair, corrected mismatch direction Enumerate every combination of {fp32 head, body-dtype head} on each side per engine pair, rather than one matched + one mismatched row. This adds the production-relevant direction that was missing: a body-dtype head on the training side against vLLM's fp32 head (vLLM always emits fp32 logits in production), the prior single mismatched row had it reversed. Per-side head columns make each row's config explicit; the decomposition mirrors the same precision/head/pair combinations. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision_cross_engine.py | 105 +++++++++++++---------- 1 file changed, 58 insertions(+), 47 deletions(-) diff --git a/tools/evaluate_precision_cross_engine.py b/tools/evaluate_precision_cross_engine.py index 8d72ed61d..b385b3d8c 100644 --- a/tools/evaluate_precision_cross_engine.py +++ b/tools/evaluate_precision_cross_engine.py @@ -15,9 +15,10 @@ each engine's variant names to three canonical configs and reports two things per regime: 1. the δ distribution (mean = systematic bias, RMS, max, per-sequence sum, PPO clip fraction) over - every engine pair, for the fp32 floor and, per low precision (bf16, fp16), the matched config - (fp32 head on both sides) and the mismatched config (A keeps the fp32 head, B runs the head in - the body dtype — e.g. vLLM's as-shipped bf16 head vs a trainer's fp32 head); + every engine pair, for the fp32 floor and, per low precision (bf16, fp16), the full 2×2 of head + choice (fp32 upcast vs body-dtype head) on each side — so both matched cases (both fp32, both + body) and both mismatch directions appear. In production vLLM emits fp32 logits, so the relevant + mismatch is a body-dtype head on the *training* side against vLLM's fp32 head; 2. the error-correlation decomposition. With e = log π_bf16 − log π_fp32 per engine, δ_AB = (fp32 floor) + (e_A − e_B), and RMS(e_A − e_B) is governed by ρ = corr(e_A, e_B): ρ→1 means the engines round the same way and the errors cancel (gap collapses to the floor); @@ -68,13 +69,18 @@ _ENGINE_LABELS = {"fast_llm": "Fast-LLM", "deepspeed": "DeepSpeed", "vllm": "vLLM"} -def _fp32_head(precision: str) -> str: - return f"{precision}_fp32head" +# Head precision per side: "fp32" upcasts the head / logits to fp32; "body" runs the head in the body +# dtype. In production vLLM always emits fp32 logits, so the relevant mismatch is a body-dtype head on +# the *training* side against vLLM's fp32 head. +_HEADS = ("fp32", "body") -def _body_head(precision: str) -> str: - # Head / logits in the body dtype (no fp32 upcast) — the low-precision-head config. - return f"{precision}_{precision}head" +def _head_config(precision: str, head: str) -> str: + return f"{precision}_fp32head" if head == "fp32" else f"{precision}_{precision}head" + + +def _head_label(precision: str, head: str) -> str: + return "fp32" if head == "fp32" else precision def _rms(x: torch.Tensor) -> float: @@ -125,15 +131,19 @@ def _print_delta_table( ) -> None: print(f"\n=== Cross-engine log π gap{f' [{label}]' if label else ''} (δ = A − B, nats) ===") print(f"(per-sequence sum over {sequence_length} tokens; clip = fraction with |δ| > {epsilon})") - print("(matched = both engines fp32 head; mismatched = A fp32 head, B body-dtype head)") - header = f"{'group':<22} {'A − B':<22} {'mean δ':>10} {'RMS δ':>9} {'max|δ|':>9} {'Σδ (seq)':>10} {'clip%':>7}" + print("(head = fp32 upcast vs body-dtype; production has vLLM head fp32, so the relevant") + print(" mismatch is a body-dtype head on the trainer side against vLLM's fp32 head)") + header = ( + f"{'group':<11} {'A − B':<22} {'A head':>7} {'B head':>7} {'mean δ':>10} {'RMS δ':>9}" + f" {'max|δ|':>9} {'Σδ (seq)':>10} {'clip%':>7}" + ) print(header) print("-" * len(header)) - for group, engine_a, engine_b, stats in rows: + for group, engine_a, engine_b, head_a, head_b, stats in rows: pair = f"{_ENGINE_LABELS[engine_a]} − {_ENGINE_LABELS[engine_b]}" print( - f"{group:<22} {pair:<22} {stats['mean']:>+10.4f} {stats['rms']:>9.4f} {stats['max']:>9.4f}" - f" {stats['sum']:>+10.2f} {stats['clip'] * 100:>6.2f}%" + f"{group:<11} {pair:<22} {head_a:>7} {head_b:>7} {stats['mean']:>+10.4f} {stats['rms']:>9.4f}" + f" {stats['max']:>9.4f} {stats['sum']:>+10.2f} {stats['clip'] * 100:>6.2f}%" ) @@ -144,16 +154,16 @@ def _print_decomposition_table(rows: list[tuple[str, str, str, dict[str, float]] ) print("(δ = floor + (e_A − e_B); ρ = corr(e_A, e_B): ρ→1 errors cancel, ρ→0 add in quadrature)") header = ( - f"{'config':<16} {'A − B':<22} {'ρ(err)':>8} {'RMS e_A':>9} {'RMS e_B':>9}" + f"{'prec':<5} {'A − B':<22} {'A head':>7} {'B head':>7} {'ρ(err)':>8} {'RMS e_A':>9} {'RMS e_B':>9}" f" {'RMS(e_A−e_B)':>13} {'RMS floor':>10} {'slope':>8}" ) print(header) print("-" * len(header)) - for config, engine_a, engine_b, stats in rows: + for precision, engine_a, engine_b, head_a, head_b, stats in rows: pair = f"{_ENGINE_LABELS[engine_a]} − {_ENGINE_LABELS[engine_b]}" print( - f"{config:<16} {pair:<22} {stats['rho']:>8.4f} {stats['rms_a']:>9.4f} {stats['rms_b']:>9.4f}" - f" {stats['rms_diff']:>13.4f} {stats['rms_floor']:>10.4f} {stats['slope']:>+8.4f}" + f"{precision:<5} {pair:<22} {head_a:>7} {head_b:>7} {stats['rho']:>8.4f} {stats['rms_a']:>9.4f}" + f" {stats['rms_b']:>9.4f} {stats['rms_diff']:>13.4f} {stats['rms_floor']:>10.4f} {stats['slope']:>+8.4f}" ) @@ -182,45 +192,46 @@ def main() -> None: available = [engine for engine in _ENGINE_ORDER if engine in engines] pairs = list(itertools.combinations(available, 2)) - # δ table: fp32 floor, then per precision the matched (both fp32 head) and mismatched (A fp32 head, - # B body-dtype head) gaps over every available pair. - delta_rows: list[tuple[str, str, str, dict[str, float]]] = [] + # δ table: fp32 floor, then per precision the full 2×2 of head choice (fp32 vs body-dtype) on each + # side, over every available pair. The decomposition mirrors each precision/head/pair combination. + delta_rows: list[tuple[str, str, str, str, str, dict[str, float]]] = [] for engine_a, engine_b in pairs: if "fp32" in engines[engine_a] and "fp32" in engines[engine_b]: stats = _delta_stats(engines[engine_a]["fp32"], engines[engine_b]["fp32"], args.epsilon) - delta_rows.append(("fp32 floor", engine_a, engine_b, stats)) - for precision in _PRECISIONS: - fp32_head, body_head = _fp32_head(precision), _body_head(precision) - for group_label, config_b in ((f"{precision} matched", fp32_head), (f"{precision} mismatched", body_head)): - for engine_a, engine_b in pairs: - if fp32_head in engines[engine_a] and config_b in engines[engine_b]: - stats = _delta_stats(engines[engine_a][fp32_head], engines[engine_b][config_b], args.epsilon) - delta_rows.append((group_label, engine_a, engine_b, stats)) + delta_rows.append(("fp32 floor", engine_a, engine_b, "fp32", "fp32", stats)) - decomposition_rows: list[tuple[str, str, str, dict[str, float]]] = [] + decomposition_rows: list[tuple[str, str, str, str, str, dict[str, float]]] = [] for precision in _PRECISIONS: - fp32_head, body_head = _fp32_head(precision), _body_head(precision) - for mode, config_b in (("matched", fp32_head), ("mismatched", body_head)): + for head_a, head_b in itertools.product(_HEADS, repeat=2): + config_a, config_b = _head_config(precision, head_a), _head_config(precision, head_b) + label_a, label_b = _head_label(precision, head_a), _head_label(precision, head_b) for engine_a, engine_b in pairs: - if not ( - "fp32" in engines[engine_a] - and "fp32" in engines[engine_b] - and fp32_head in engines[engine_a] - and config_b in engines[engine_b] - ): + if config_a not in engines[engine_a] or config_b not in engines[engine_b]: + continue + stats = _delta_stats(engines[engine_a][config_a], engines[engine_b][config_b], args.epsilon) + delta_rows.append((precision, engine_a, engine_b, label_a, label_b, stats)) + if "fp32" not in engines[engine_a] or "fp32" not in engines[engine_b]: continue - error_a = engines[engine_a][fp32_head] - engines[engine_a]["fp32"] + error_a = engines[engine_a][config_a] - engines[engine_a]["fp32"] error_b = engines[engine_b][config_b] - engines[engine_b]["fp32"] floor = engines[engine_a]["fp32"] - engines[engine_b]["fp32"] - stats = { - "rho": _corr(error_a, error_b), - "rms_a": _rms(error_a), - "rms_b": _rms(error_b), - "rms_diff": _rms(error_a - error_b), - "rms_floor": _rms(floor), - "slope": _slope(engines[engine_b][config_b], engines[engine_a][fp32_head]), - } - decomposition_rows.append((f"{precision} {mode}", engine_a, engine_b, stats)) + decomposition_rows.append( + ( + precision, + engine_a, + engine_b, + label_a, + label_b, + { + "rho": _corr(error_a, error_b), + "rms_a": _rms(error_a), + "rms_b": _rms(error_b), + "rms_diff": _rms(error_a - error_b), + "rms_floor": _rms(floor), + "slope": _slope(engines[engine_b][config_b], engines[engine_a][config_a]), + }, + ) + ) _print_delta_table(delta_rows, sequence_length, args.epsilon, args.label) _print_decomposition_table(decomposition_rows, args.label) From ca0c7b2644dc8fbd9d1c6121ea2307a150ead88a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 2 Jun 2026 18:17:27 -0400 Subject: [PATCH 39/41] Add forward-only inference mode to the precision tool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Runs a single forward in `StageMode.inference` (no optimizer, no gradient buffers) instead of forward+backward, so large models (e.g. 7B in fp32) fit where forward+backward+Adam would OOM. The LM head skips all losses in eval mode, so after setup the head(s) are forced back into train mode directly; `run_step`'s per-step `train(False)` is a guarded no-op once `_training` is False, keeping the head trained so `chosen_logprob` still logs. Only `chosen_logprob` is configured (no grad-producing loss), so no backward ever touches the absent gradient buffers. Uses a validation-phase schedule (forward-only but still produces labels, unlike inference phase). Verified the forward-only log π is bitwise-identical to the forward+backward path. Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision.py | 82 ++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 3bf5e4c12..29ee1428b 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -90,6 +90,14 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): " (`tools/evaluate_precision_deepspeed.py`) can consume the identical model input.", hint=FieldHint.feature, ) + forward_only: bool = Field( + default=False, + desc="Run a single forward pass in inference mode (`StageMode.inference`, no optimizer or" + " gradient buffers) instead of forward+backward. Fits large models (e.g. 7B in fp32) that would" + " OOM with gradient and optimizer state. Only the per-token log π (chosen_logprob) and forward" + " activations are captured — no gradient/parameter-gradient tables, and no gradient scaling.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -106,7 +114,12 @@ def run(self) -> None: runs.update(self.variants) scales: dict[str, float] = {} for name, variant_overrides in runs.items(): - scales[name] = self._calibrate_and_run(name, variant_overrides, input_ids) + if self.forward_only: + # No backward -> no gradients to scale; run once directly. + self._run_one(name, variant_overrides, input_ids) + scales[name] = 1.0 + else: + scales[name] = self._calibrate_and_run(name, variant_overrides, input_ids) self._save_chosen_logprob(name) ref_artifacts = self._artifact_path(_REFERENCE_NAME) @@ -266,8 +279,6 @@ def _run_one( # Tool-required overrides win over variants — a variant must not silently disable tensor logging. tool_overrides: dict[tuple[str, ...], typing.Any] = { ("model", "multi_stage", "debug_layer_outputs"): log_level, - ("model", "multi_stage", "debug_layer_gradients"): log_level, - ("model", "multi_stage", "debug_all_param_gradients"): log_level, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], @@ -275,17 +286,22 @@ def _run_one( # Contributes no gradient (weight=0); the comparison code picks it up by name. ("model", "base_model", "head", "losses", _CHOSEN_LOGPROB_NAME): {"type": "chosen_logprob"}, } - # When the user hasn't configured any loss, the head defaults to cross-entropy. Adding a - # loss explicitly suppresses that default, so re-add it so gradients still flow. - if not (self.model.base_model.head.losses or {}): - tool_overrides[("model", "base_model", "head", "losses", "cross_entropy")] = {"type": "label"} + if not self.forward_only: + tool_overrides[("model", "multi_stage", "debug_layer_gradients")] = log_level + tool_overrides[("model", "multi_stage", "debug_all_param_gradients")] = log_level + # When the user hasn't configured any loss, the head defaults to cross-entropy. Adding a + # loss explicitly suppresses that default, so re-add it so gradients still flow. + if not (self.model.base_model.head.losses or {}): + tool_overrides[("model", "base_model", "head", "losses", "cross_entropy")] = {"type": "label"} + # In forward-only mode only chosen_logprob runs (no grad-producing loss), so no backward + # happens and `StageMode.inference` (which allocates no gradient buffers) is sufficient. logger.info(f"=== Running {name!r} ===") if variant_overrides: logger.info(f"Variant overrides: {variant_overrides}") trainer_class = TrainerConfig.get_subclass(_MODEL_TYPE) trainer_config = trainer_class.from_dict(base_dict, fp32_dtypes, variant_updates, tool_overrides) trainer_config.configure_logging() - _run_fixed_input(trainer_config, input_ids, self.sequence_length) + _run_fixed_input(trainer_config, input_ids, self.sequence_length, forward_only=self.forward_only) def _compare( self, @@ -329,12 +345,20 @@ def _compare( return rows -def _run_fixed_input(config, input_ids, sequence_length: int) -> None: - """Lean forward+backward on a fixed, already-preprocessed input — like `InferenceRunner` but with a - training-phase schedule + an (lr-0) optimizer so `run_step` runs the backward and the existing - chosen-logprob loss / `debug_all_param_gradients` logging captures everything. Replaces the trainer - + data pipeline so the model sees exactly `input_ids` (the pipeline would re-randomize it) and so the - tool stops paying for training/data-loading infrastructure it doesn't need.""" +def _run_fixed_input(config, input_ids, sequence_length: int, *, forward_only: bool = False) -> None: + """Lean run on a fixed, already-preprocessed input — like `InferenceRunner` but feeding a fixed input + so the model sees exactly `input_ids` (the data pipeline would re-randomize it) and the tool stops + paying for training/data-loading infrastructure it doesn't need. + + Default mode is a training-phase schedule + an (lr-0) optimizer so `run_step` runs the backward and + the chosen-logprob loss / `debug_all_param_gradients` logging captures everything. + + `forward_only=True` runs a single forward in inference mode: `StageMode.inference` (no gradient + buffers), no optimizer, and a validation-phase schedule (forward-only, but still produces labels — + `PhaseType.inference` would zero `num_labels`). The head skips all losses in eval mode, so after setup + the head(s) are forced back into train mode directly; `run_step`'s per-step `multi_stage.train(False)` + is a guarded no-op once `_training` is False, so the head stays trained and logs chosen_logprob. + Valid only because no grad-producing loss is configured, so no backward touches the missing buffers.""" import gc import torch @@ -347,31 +371,39 @@ def _run_fixed_input(config, input_ids, sequence_length: int) -> None: from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule + phase = PhaseType.validation if forward_only else PhaseType.training distributed = Distributed(config.model.distributed) run = config.get_run(distributed) + optimizer = None with run: multi_stage = config.model.get_model_class()( - config.model, optimizer_state_names=config.optimizer.state_names() + config.model, optimizer_state_names=() if forward_only else config.optimizer.state_names() ) with torch.no_grad(): - multi_stage.setup(distributed, mode=StageMode.training) + multi_stage.setup(distributed, mode=StageMode.inference if forward_only else StageMode.training) if config.pretrained.path is not None and config.pretrained.model_weights: multi_stage.load_checkpoint(config.pretrained) else: multi_stage.initialize_weights() - param_groups, grads_for_norm = multi_stage.get_param_groups(ParamGroup) - optimizer = config.optimizer.optimizer_cls( - config.optimizer, param_groups=param_groups, grads_for_norm=grads_for_norm, distributed=distributed - ) - optimizer.reset_state() + if not forward_only: + param_groups, grads_for_norm = multi_stage.get_param_groups(ParamGroup) + optimizer = config.optimizer.optimizer_cls( + config.optimizer, param_groups=param_groups, grads_for_norm=grads_for_norm, distributed=distributed + ) + optimizer.reset_state() runner = ScheduleRunner( config=config.schedule, multi_stage=multi_stage, distributed_config=config.model.distributed ) with torch.no_grad(): runner.setup(distributed, optimizer) - preprocessing_config = multi_stage.get_preprocessing_config( - PhaseType.training, config.schedule.micro_batch_splits - ) + if forward_only: + from fast_llm.layers.language_model.head import LanguageModelHead + + multi_stage.train(False) + for module in multi_stage.base_model.modules(): + if isinstance(module, LanguageModelHead): + module.train(True) + preprocessing_config = multi_stage.get_preprocessing_config(phase, config.schedule.micro_batch_splits) # `get_model_inputs` splits off `num_labels` tokens for the shifted next-token labels, so the # actual model input is `len(tokens) - num_labels`. The schedule meta must match that length. schedule = Schedule( @@ -379,7 +411,7 @@ def _run_fixed_input(config, input_ids, sequence_length: int) -> None: multi_stage=multi_stage, batch_meta=preprocessing_config.get_input_meta(sequence_length - preprocessing_config.num_labels), distributed_config=config.model.distributed, - phase=PhaseType.training, + phase=phase, ) tokens = input_ids.flatten().to(device=distributed.device, dtype=torch.int64) batch = LanguageModelBatch(tokens=tokens, lengths=[tokens.numel()]) From f2b9d212a10f140a8f01867d7257482e396c8272 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 2 Jun 2026 18:36:48 -0400 Subject: [PATCH 40/41] Add forward-only mode to the DeepSpeed precision tool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `--forward-only` initializes the DeepSpeed engine without an optimizer (no fp32 master copy or Adam state) and runs a single eval()+no_grad forward, for large models (e.g. 7B in fp32) where forward+backward+Adam would OOM. The engine is kept rather than bypassed for a plain HF forward: DeepSpeed's bf16/fp16 forward is not bit-identical to a plain HF forward in the same dtype — measured ~0.032 nats mean / 0.22 max on Qwen2.5-0.5B bf16, comparable to the cross-engine signal itself — so bypassing it would shift the log π. The no-optimizer engine forward is bitwise-identical to the full forward+backward path across all variants (fp32/bf16/fp16, head on/off). Co-Authored-By: Claude Opus 4.8 (1M context) --- tools/evaluate_precision_deepspeed.py | 38 +++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/tools/evaluate_precision_deepspeed.py b/tools/evaluate_precision_deepspeed.py index f02c8e17a..015f114ce 100644 --- a/tools/evaluate_precision_deepspeed.py +++ b/tools/evaluate_precision_deepspeed.py @@ -111,11 +111,11 @@ def build_input_ids(tokenizer, vocab_size: int, sequence_length: int, device: to return ids.unsqueeze(0).to(device) -def _ds_config(dtype: torch.dtype) -> dict[str, typing.Any]: - config: dict[str, typing.Any] = { - "train_micro_batch_size_per_gpu": 1, - "optimizer": {"type": "Adam", "params": {"lr": 1e-6}}, - } +def _ds_config(dtype: torch.dtype, forward_only: bool = False) -> dict[str, typing.Any]: + config: dict[str, typing.Any] = {"train_micro_batch_size_per_gpu": 1} + if not forward_only: + # No optimizer for forward-only: avoids the fp32 master copy + Adam state that would OOM a 7B run. + config["optimizer"] = {"type": "Adam", "params": {"lr": 1e-6}} if dtype == torch.bfloat16: config["bf16"] = {"enabled": True} elif dtype == torch.float16: @@ -140,9 +140,15 @@ def capture_variant( input_ids: torch.Tensor, attn_implementation: str, random_init: bool = False, + forward_only: bool = False, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: - """Forward + backward one variant through a DeepSpeed engine. Returns (chosen_logprob, - {param_name: gradient}), both on CPU in fp32.""" + """Capture one variant. Returns (chosen_logprob, {param_name: gradient}), both on CPU in fp32. + + Default: forward + backward through a DeepSpeed engine. `forward_only=True` initializes the same + DeepSpeed engine but without an optimizer (no fp32 master copy / Adam state, which would OOM a 7B + run) and runs a single `eval()` + `no_grad` forward (returns empty gradients). The engine is kept — + DeepSpeed's bf16/fp16 forward is not bit-identical to a plain HF forward in the same dtype, so + bypassing it would shift the measured log π.""" import deepspeed import transformers @@ -156,6 +162,15 @@ def capture_variant( ) if fp32_head: apply_fp32_lm_head(model) + if forward_only: + engine, *_ = deepspeed.initialize(model=model, config=_ds_config(dtype, forward_only=True)) + engine.eval() + with torch.no_grad(): + logprob = chosen_logprob(engine(input_ids).logits, input_ids).detach().float().cpu() + del engine, model + torch.cuda.empty_cache() + return logprob, {} + engine, *_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=_ds_config(dtype)) outputs = engine(input_ids) logprob = chosen_logprob(outputs.logits, input_ids) @@ -245,6 +260,13 @@ def main() -> None: help="Build the model from config with random weights instead of loading the pretrained" " checkpoint (contrast with the pretrained run; weights won't match Fast-LLM's random init).", ) + parser.add_argument( + "--forward-only", + action="store_true", + help="Initialize the DeepSpeed engine without an optimizer and run a single eval()+no_grad" + " forward (no optimizer state, no gradients). Fits large models (e.g. 7B fp32) where" + " forward+backward+Adam would OOM. The gradient table is then empty.", + ) parser.add_argument( "--output-dir", default=None, @@ -286,7 +308,7 @@ def main() -> None: for name, dtype, fp32_head in _VARIANTS: logger.info(f"=== variant {name} (dtype={dtype}, fp32_head={fp32_head}) ===") logprob, grads = capture_variant( - args.model, dtype, fp32_head, input_ids, args.attn_implementation, args.random_init + args.model, dtype, fp32_head, input_ids, args.attn_implementation, args.random_init, args.forward_only ) if args.output_dir is not None: output_dir = pathlib.Path(args.output_dir) From 32c9545c7eca5e17fc53aa6a4ec87054296246c3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 2 Jun 2026 19:33:20 -0400 Subject: [PATCH 41/41] Add Qwen2.5-7B precision-evaluation config (forward-only) 7B is untied, so the fp32 LM head genuinely changes the logits (unlike the tied 0.5B). Forward-only so the fp32 reference fits in memory. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/evaluate_precision/qwen_7b.yaml | 28 ++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 examples/evaluate_precision/qwen_7b.yaml diff --git a/examples/evaluate_precision/qwen_7b.yaml b/examples/evaluate_precision/qwen_7b.yaml new file mode 100644 index 000000000..508960b83 --- /dev/null +++ b/examples/evaluate_precision/qwen_7b.yaml @@ -0,0 +1,28 @@ +# Precision-evaluation config on Qwen2.5-7B. Unlike the 0.5B model, the 7B has untied embeddings, +# so the LM head is a real parameter and an fp32 head genuinely changes the logits (on tied models +# it can be a no-op). `forward_only` runs a single inference-mode forward so the fp32 reference fits +# in memory — forward+backward+Adam in fp32 would not. +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen_7b.yaml +pretrained: + path: Qwen/Qwen2.5-7B + format: qwen2 +output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_7b +sequence_length: 2048 +forward_only: true +variants: + # compute bf16, lm head in compute dtype. + bf16: + model.distributed.compute_dtype: bfloat16 + # compute bf16, fp32 lm head (the stack default). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # compute fp16, lm head in compute dtype. + fp16: + model.distributed.compute_dtype: float16 + # compute fp16, fp32 lm head. + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true