From 0721ed2544596d905fd90786ca23a88a5e15a703 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 15 Apr 2026 11:39:19 +0800 Subject: [PATCH 1/3] [Debug] Add AutoswitchGEmm for Debug Precision Tool --- docs/debug/1_getting_started.rst | 48 +- docs/debug/2_config_file_structure.rst | 22 + docs/debug/3_api_features.rst | 1 + docs/debug/autoswitch_gemm_example.yaml | 72 +++ .../debug/features/autoswitch_gemm.py | 585 ++++++++++++++++++ .../debug/pytorch/debug_quantization.py | 87 ++- transformer_engine/pytorch/module/base.py | 2 +- 7 files changed, 800 insertions(+), 17 deletions(-) create mode 100644 docs/debug/autoswitch_gemm_example.yaml create mode 100644 transformer_engine/debug/features/autoswitch_gemm.py diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index cce2616998..ac36acf990 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -149,10 +149,11 @@ Inspecting the logs ------------------- -Let's look at the files with the logs. Two files will be created: +Let's look at the files with the logs. At least two files will be created: 1. debug logs. 2. statistics logs. +3. optional feature-specific logs (for example AutoswitchGemm metrics). Let's look inside them! @@ -214,6 +215,51 @@ The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank- INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000004 value=0.9996 INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 +AutoswitchGemm quick guide +-------------------------- + +``AutoswitchGemm`` monitors quantization quality and can dynamically switch selected GEMMs +to high precision when thresholds are exceeded. + +Minimal config example: + +.. 
code-block:: yaml + + autoswitch_fc_layers: + enabled: True + layers: + layer_types: [fc1, fc2] + transformer_engine: + AutoswitchGemm: + enabled: True + gemms: [fprop, dgrad, wgrad] + underflow_threshold_pct: 1.0 + mse_threshold: 1.0e-4 + # Needed only if the layer uses fp8 model parameters and + # you want fprop/dgrad to be able to switch to high precision. + allow_fp8_model_params_dequantized_weight: False + freq: 1 + +Behavior summary: + +1. For each ``(layer, gemm)``, AutoswitchGemm tracks the latest tensor metrics and applies + OR logic across monitored tensors: if any tensor breaches thresholds, that GEMM switches. +2. Metrics computed in iteration ``n`` are consumed in iteration ``n`` only. +3. If thresholds are not breached in the current iteration, the GEMM stays quantized. + +When AutoswitchGemm is enabled, an additional directory is created under ``log_dir``: + +``nvdlfw_inspect_autoswitchgemm_logs/nvdlfw_inspect_globalrank-{rank}.log`` + +It contains per-rank, per-iteration metrics such as: + +- ``{layer_name}_{gemm}_{tensor_name}_underflow_pct`` +- ``{layer_name}_{gemm}_{tensor_name}_mse`` +- ``{layer_name}_{gemm}_quantized_enabled`` +- ``{layer_name}_{gemm}_disable_until_iter`` +- ``{layer_name}_{gemm}_switch_blocked_fp8_model_params`` +- ``{layer_name}_{gemm}_fp8_model_params_dequantized_fallback`` + Logging using TensorBoard ------------------------- diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst index 3ade970b57..28da6beab3 100644 --- a/docs/debug/2_config_file_structure.rst +++ b/docs/debug/2_config_file_structure.rst @@ -220,6 +220,28 @@ We can use both structs for tensors and GEMMs. The tensors_struct should be nest tensor_feature_param2: value gemm_feature_param1: value +AutoswitchGemm notes +-------------------- + +``AutoswitchGemm`` supports both global and per-GEMM configuration. + +- Use ``gemms: [...]`` for one shared policy. +- Use ``gemms_struct`` to set per-GEMM thresholds. 
+ +If ``tensors``/``tensors_struct`` are omitted, monitored tensors are inferred from GEMMs: + +- ``fprop`` -> ``activation``, ``weight`` +- ``dgrad`` -> ``gradient``, ``weight`` +- ``wgrad`` -> ``activation``, ``gradient`` + +Other important keys: + +- ``underflow_threshold_pct``: switch trigger based on underflow percentage. +- ``mse_threshold``: switch trigger based on quantization MSE. +- metrics are consumed in the same iteration where they are computed. +- ``allow_fp8_model_params_dequantized_weight``: allows ``fprop``/``dgrad`` switching + for layers with FP8 model parameters by using dequantized temporary weights. + Enabling or Disabling Sections and Features ------------------------------------------- diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index a8a644d5b5..1972a3d1d8 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -10,6 +10,7 @@ Debug features .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats .. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats .. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM +.. autoapiclass:: transformer_engine.debug.features.autoswitch_gemm.AutoswitchGemm .. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling .. 
autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant diff --git a/docs/debug/autoswitch_gemm_example.yaml b/docs/debug/autoswitch_gemm_example.yaml new file mode 100644 index 0000000000..c24462a67e --- /dev/null +++ b/docs/debug/autoswitch_gemm_example.yaml @@ -0,0 +1,72 @@ +# Example config for transformer_engine.debug.features.autoswitch_gemm.AutoswitchGemm +# +# Usage: +# import nvdlfw_inspect.api as debug_api +# debug_api.initialize( +# config_file="docs/debug/autoswitch_gemm_example.yaml", +# feature_dirs=["transformer_engine/debug/features"], +# log_dir="./log", +# ) +# ... +# debug_api.step() # call once per training step + +autoswitch_attention_blocks: + enabled: True + layers: + # Match attention linear layers, e.g. *.qkv / *.proj + layer_name_regex_pattern: ".*(qkv|proj).*" + transformer_engine: + AutoswitchGemm: + enabled: True + + # Optional. If omitted, tensors are inferred from selected gemms: + # fprop -> [activation, weight], dgrad -> [gradient, weight], + # wgrad -> [activation, gradient]. + tensors: [activation, weight, gradient] + + # Per-GEMM switching policy. + gemms_struct: + - gemm: fprop + underflow_threshold_pct: 1.0 + mse_threshold: 1.0e-4 + - gemm: dgrad + underflow_threshold_pct: 1.5 + mse_threshold: 1.5e-4 + - gemm: wgrad + underflow_threshold_pct: 2.0 + mse_threshold: 2.0e-4 + + # For layers with fp8 model parameters: + # - False: keep fprop/dgrad quantized + # - True: allow high-precision switch via temporary dequantized weights + allow_fp8_model_params_dequantized_weight: False + + # Collect metrics every step after warmup. + freq: 1 + start_step: 10 + end_step: 5000 + + +autoswitch_mlp_blocks: + enabled: True + layers: + layer_types: [fc1, fc2] + transformer_engine: + AutoswitchGemm: + enabled: True + + # Simpler global policy (shared by selected GEMMs). + gemms: [fprop, wgrad] + tensors: [activation, weight, gradient] + + underflow_threshold_pct: 3.0 + mse_threshold: 3.0e-4 + + # Example sparse monitoring windows. 
+ freq: 2 + start_end_list: + - [0, 300] + - [800, 3000] + +# Autoswitch per-rank metrics are written to: +# /nvdlfw_inspect_autoswitchgemm_logs/nvdlfw_inspect_globalrank-.log diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py new file mode 100644 index 0000000000..807947f627 --- /dev/null +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -0,0 +1,585 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""AutoswitchGemm Feature support for nvidia-dlframework-inspect.""" + +import copy +import logging +import os +from typing import Dict, Optional, Set, Tuple + +import torch +import torch.distributed as dist + +import nvdlfw_inspect.api as debug_api +from nvdlfw_inspect.logging import get_logger +from nvdlfw_inspect.registry import Registry, api_method + +from transformer_engine.debug.features.api import TEConfigAPIMapper +from transformer_engine.debug.features.utils import next_enabled_iter + + +class _AutoswitchGemmMetricLogger: + """Writes per-rank autoswitch metrics to a dedicated log file.""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if _AutoswitchGemmMetricLogger._initialized: + return + self.root_dir = None + self.log_file = None + self.logger = None + _AutoswitchGemmMetricLogger._initialized = True + + @staticmethod + def _get_rank() -> int: + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + def _expected_paths(self, root_log_dir: str) -> Tuple[str, str]: + rank = self._get_rank() + root_dir = os.path.join(root_log_dir, "nvdlfw_inspect_autoswitchgemm_logs") + log_file = os.path.join(root_dir, f"nvdlfw_inspect_globalrank-{rank}.log") + return root_dir, log_file + + def initialize(self, root_log_dir: str) -> None: + 
"""Initialize rank-local logger under autoswitch log directory.""" + root_dir, log_file = self._expected_paths(root_log_dir) + os.makedirs(root_dir, exist_ok=True) + + rank = self._get_rank() + logger_name = f"nvdlfw_inspect.autoswitchgemm.rank{rank}" + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + logger.propagate = False + + for handler in list(logger.handlers): + logger.removeHandler(handler) + handler.close() + + file_handler = logging.FileHandler(log_file, mode="a") + file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) + logger.addHandler(file_handler) + + self.root_dir = root_dir + self.log_file = log_file + self.logger = logger + + def ensure_initialized(self, root_log_dir: Optional[str]) -> bool: + """Ensure logger tracks current debug session's root log dir.""" + if not root_log_dir: + return False + expected_root_dir, expected_log_file = self._expected_paths(root_log_dir) + if ( + self.logger is None + or self.root_dir != expected_root_dir + or self.log_file != expected_log_file + or not os.path.isdir(expected_root_dir) + ): + self.initialize(root_log_dir) + return self.logger is not None + + def log_scalar( + self, + layer_name: str, + gemm: str, + metric_name: str, + iteration: int, + value: float, + ) -> None: + """Log metric in LogTensorStats-like `iteration/value` format.""" + if self.logger is None: + return + metric_key = f"{layer_name}_{gemm}_{metric_name}" + self.logger.info( + f"{metric_key} \t\t\t\t iteration={iteration:06d} \t\t\t\t value={value:.8f}" + ) + + +def _get_autoswitch_metric_logger() -> _AutoswitchGemmMetricLogger: + """Get singleton autoswitch metric logger.""" + return _AutoswitchGemmMetricLogger() + + +class _GemmSwitchState: + """Autoswitch state tracked independently for each (layer, gemm).""" + + def __init__(self): + self.disable_until_iter = -1 + self.last_applied_metric_snapshot = None + self.last_reason = "" + + 
+@Registry.register_feature(namespace="transformer_engine") +class AutoswitchGemm(TEConfigAPIMapper): + """ + Dynamically switches selected GEMMs between quantized and high-precision execution. + + The feature continuously monitors quantization quality for selected tensors and, + when quality degrades beyond configured thresholds, temporarily disables quantized + GEMM for the affected operation. + + The decision is made per `(layer_name, gemm)`: + + - `fp8_gemm_enabled(..., gemm="fprop")` controls FPROP GEMM + - `fp8_gemm_enabled(..., gemm="dgrad")` controls DGRAD GEMM + - `fp8_gemm_enabled(..., gemm="wgrad")` controls WGRAD GEMM + + The API name `fp8_gemm_enabled` is kept for backward compatibility with the + debug API; the switch applies to all quantized formats supported by TE. + When multiple tensors are monitored for a GEMM, their metrics are aggregated + with OR semantics: if any monitored tensor breaches thresholds, the GEMM + switches to high precision. + + Parameters + ---------- + + gemms / gemms_struct: List[str] + GEMMs to control: + + - fprop + - dgrad + - wgrad + + tensors / tensors_struct: Optional[List[str]] + Tensors to monitor: + + - activation + - weight + - gradient + + If omitted, tensors are inferred from selected GEMMs: + + - fprop -> activation, weight + - dgrad -> gradient, weight + - wgrad -> activation, gradient + + underflow_threshold_pct: float, default = 5.0 + Trigger switch to high precision if underflow percentage exceeds this value. + + mse_threshold: float, default = 1e-4 + Trigger switch to high precision if quantization MSE exceeds this value. + + The switch decision is same-iteration only: + metrics computed at iteration `n` are consumed in iteration `n`. + There is no cross-iteration hold window. 
+ + allow_fp8_model_params_dequantized_weight: bool, default = False + If True, allows `fprop`/`dgrad` to switch to high precision even when + fp8 model parameters are enabled by using a temporary dequantized weight + tensor for GEMM execution. + If False, `fprop`/`dgrad` stay quantized for such layers. + + freq/start_step/end_step/start_end_list: Optional + Sampling controls for tensor inspection calls. + + Example + ------- + .. code-block:: yaml + + example_autoswitch_gemm: + enabled: True + layers: + layer_types: [qkv] + transformer_engine: + AutoswitchGemm: + enabled: True + gemms: [fprop, dgrad, wgrad] + underflow_threshold_pct: 3.0 + mse_threshold: 1e-4 + # decision is computed and consumed in the same iteration + """ + + _GEMM_TO_TENSORS = { + "fprop": {"activation", "weight"}, + "dgrad": {"gradient", "weight"}, + "wgrad": {"activation", "gradient"}, + } + + # Mirrors DebugQuantizer's internal mapping. + _TENSOR_TO_GEMMS = { + "weight": ("fprop", "dgrad"), + "activation": ("fprop", "wgrad"), + "gradient": ("dgrad", "wgrad"), + "output": ("fprop", None), + "wgrad": ("wgrad", None), + "dgrad": ("dgrad", None), + } + + _DEFAULT_UNDERFLOW_THRESHOLD_PCT = 5.0 + _DEFAULT_MSE_THRESHOLD = 1e-4 + _DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._gemm_state: Dict[Tuple[str, str], _GemmSwitchState] = {} + self._latest_metrics: Dict[Tuple[str, str], Dict[str, float | int | str]] = {} + self._layer_has_fp8_model_params: Dict[str, bool] = {} + + def parse_config_and_api(self, config, **kwargs): + """ + Parse config for GEMM-routing and tensor-inspection APIs. + + Unlike the default TEConfigAPIMapper behavior, this implementation supports + tensor inspection even when `tensors` is omitted by inferring monitored + tensors from selected GEMMs. 
+ """ + processed_config = None + config_copy = copy.deepcopy(config) + + gemm = kwargs.get("gemm", None) + tensor_name = kwargs.get("tensor_name", None) + + if gemm is not None and tensor_name is None: + processed_config = self._process_transformer_engine_config(config_copy, **kwargs) + elif tensor_name is not None: + if "tensors" in config_copy or "tensors_struct" in config_copy: + processed_config = self._process_tensor_config(config_copy, tensor_name) + else: + monitored_tensors = self._infer_monitored_tensors(config_copy) + if tensor_name not in monitored_tensors: + return False, None + processed_config = config_copy + processed_config["tensor"] = tensor_name + + if not processed_config: + return False, None + + if "enabled" in processed_config: + processed_config.pop("enabled") + + return True, processed_config + + def _infer_monitored_tensors(self, config: Dict) -> Set[str]: + """Infer tensors to inspect from configured GEMMs.""" + configured_gemms = self._extract_configured_gemms(config) + if not configured_gemms: + configured_gemms = set(self._GEMM_TO_TENSORS.keys()) + + tensors = set() + for gemm in configured_gemms: + self._validate_gemm(gemm) + tensors.update(self._GEMM_TO_TENSORS[gemm]) + return tensors + + @staticmethod + def _extract_configured_gemms(config: Dict) -> Set[str]: + """Extract GEMM names from config keys `gemm`, `gemms`, and `gemms_struct`.""" + gemms = set() + if "gemm" in config: + gemms.add(config["gemm"]) + if "gemms" in config: + gemms.update(config["gemms"]) + if "gemms_struct" in config: + for cfg in config["gemms_struct"]: + if "gemm" in cfg: + gemms.add(cfg["gemm"]) + return gemms + + @staticmethod + def _config_float(config: Dict, key: str, default: Optional[float]) -> Optional[float]: + """Read optional float value from config.""" + value = config.get(key, default) + if value is None: + return None + return float(value) + + @staticmethod + def _config_bool(config: Dict, key: str, default: bool) -> bool: + """Read bool value 
from config.""" + value = config.get(key, default) + if isinstance(value, str): + return value.strip().lower() in ("1", "true", "yes", "on") + return bool(value) + + @staticmethod + def _get_root_log_dir() -> Optional[str]: + """Best-effort retrieval of nvdlfw_inspect root log directory.""" + try: + root_log_dir = getattr(get_logger(), "root_log_dir", None) + except Exception: # pylint: disable=broad-except + return None + return root_log_dir + + def _get_metrics_logger(self) -> Optional[_AutoswitchGemmMetricLogger]: + """Return initialized autoswitch metric logger if log dir is available.""" + metric_logger = _get_autoswitch_metric_logger() + if metric_logger.ensure_initialized(self._get_root_log_dir()): + return metric_logger + return None + + def _get_or_create_state(self, layer_name: str, gemm: str) -> _GemmSwitchState: + key = (layer_name, gemm) + if key not in self._gemm_state: + self._gemm_state[key] = _GemmSwitchState() + return self._gemm_state[key] + + def _update_metric( + self, + layer_name: str, + gemm: str, + iteration: int, + tensor_name: str, + underflow_pct: float, + mse: float, + ) -> None: + """Store the latest quality metric for a `(layer, gemm)` pair.""" + metric_logger = self._get_metrics_logger() + if metric_logger is not None: + metric_logger.log_scalar( + layer_name, gemm, f"{tensor_name}_underflow_pct", iteration, underflow_pct + ) + metric_logger.log_scalar(layer_name, gemm, f"{tensor_name}_mse", iteration, mse) + + key = (layer_name, gemm) + entry = self._latest_metrics.get(key) + + if entry is None or int(entry["iteration"]) < iteration: + self._latest_metrics[key] = { + "iteration": iteration, + "underflow_pct": underflow_pct, + "mse": mse, + "tensor_name": tensor_name, + } + return + + if int(entry["iteration"]) == iteration: + if underflow_pct >= float(entry["underflow_pct"]): + entry["underflow_pct"] = underflow_pct + entry["tensor_name"] = tensor_name + entry["mse"] = max(float(entry["mse"]), mse) + + @staticmethod + def 
_dequantize_like( + quantized_tensor, + dtype: torch.dtype, + shape: torch.Size, + ) -> Optional[torch.Tensor]: + """Best-effort dequantization helper used for quality metrics.""" + if quantized_tensor is None or not hasattr(quantized_tensor, "dequantize"): + return None + + try: + dequantized = quantized_tensor.dequantize(dtype=dtype) + except TypeError: + dequantized = quantized_tensor.dequantize() + if dequantized.dtype != dtype: + dequantized = dequantized.to(dtype) + + if dequantized.shape != shape: + expected_numel = 1 + for dim in shape: + expected_numel *= int(dim) + if dequantized.numel() != expected_numel: + return None + dequantized = dequantized.reshape(shape) + return dequantized + + @staticmethod + def _compute_metrics( + tensor: Optional[torch.Tensor], + quantized_tensor, + ) -> Optional[Tuple[float, float]]: + """Compute underflow percentage and MSE for one tensor.""" + if tensor is None or tensor.numel() == 0: + return None + + if not tensor.is_floating_point(): + return None + + dequantized = AutoswitchGemm._dequantize_like(quantized_tensor, tensor.dtype, tensor.shape) + if dequantized is None: + return None + + tensor_fp32 = tensor.float() + dequantized_fp32 = dequantized.float() + + underflow_count = torch.count_nonzero((tensor_fp32 != 0) & (dequantized_fp32 == 0)) + underflow_pct = (underflow_count.float() * 100.0 / tensor_fp32.numel()).item() + + mse = torch.mean((tensor_fp32 - dequantized_fp32) ** 2).item() + return underflow_pct, mse + + def _consume_new_metric_and_maybe_arm_switch( + self, + layer_name: str, + gemm: str, + iteration: int, + config: Dict, + state: _GemmSwitchState, + ) -> None: + """Consume current-iteration metrics and arm switch for this iteration only.""" + metric = self._latest_metrics.get((layer_name, gemm)) + if metric is None: + return + + metric_iter = int(metric["iteration"]) + if metric_iter != iteration: + # Autoswitch consumes metrics only in the same iteration they were produced. 
+ return + + metric_snapshot = ( + metric_iter, + float(metric["underflow_pct"]), + float(metric["mse"]), + str(metric["tensor_name"]), + ) + if metric_snapshot == state.last_applied_metric_snapshot: + return + state.last_applied_metric_snapshot = metric_snapshot + + underflow_threshold = self._config_float( + config, "underflow_threshold_pct", self._DEFAULT_UNDERFLOW_THRESHOLD_PCT + ) + mse_threshold = self._config_float(config, "mse_threshold", self._DEFAULT_MSE_THRESHOLD) + + reasons = [] + metric_underflow = float(metric["underflow_pct"]) + metric_mse = float(metric["mse"]) + + if underflow_threshold is not None and metric_underflow > underflow_threshold: + reasons.append( + f"underflow={metric_underflow:.4f}% > threshold={underflow_threshold:.4f}%" + ) + if mse_threshold is not None and metric_mse > mse_threshold: + reasons.append(f"mse={metric_mse:.6e} > threshold={mse_threshold:.6e}") + + if not reasons: + return + + state.disable_until_iter = iteration + state.last_reason = "; ".join(reasons) + + debug_api.log_message( + f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" + f" iter={iteration}. Triggered by {metric['tensor_name']} at iter={metric_iter}:" + f" {state.last_reason}", + layer_name, + extra_cachable_args=(gemm, "switch"), + ) + + @api_method + def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): + """Decide whether selected GEMM should run quantized (True) or high precision (False).""" + state = self._get_or_create_state(layer_name, gemm) + metric_logger = self._get_metrics_logger() + + fp8_model_params_layer = self._layer_has_fp8_model_params.get(layer_name, False) + allow_fp8_model_params_fallback = self._config_bool( + config, + "allow_fp8_model_params_dequantized_weight", + self._DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT, + ) + + # With fp8 model parameters enabled, fprop/dgrad can switch to high precision + # only when dequantized fallback is explicitly enabled in config. 
+ if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and not allow_fp8_model_params_fallback: + state.disable_until_iter = -1 + if metric_logger is not None: + metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) + metric_logger.log_scalar( + layer_name, gemm, "switch_blocked_fp8_model_params", iteration, 1.0 + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}: skip switch for {gemm} at" + f" iter={iteration} because fp8 model parameters are enabled.", + layer_name, + extra_cachable_args=(gemm, "skip_fp8_model_params"), + ) + return True, iteration + 1 + + if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: + if metric_logger is not None: + metric_logger.log_scalar( + layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}: {gemm} allows fp8-model-params" + " dequantized-weight fallback.", + layer_name, + extra_cachable_args=(gemm, "fp8_model_params_dequantized_fallback"), + ) + + self._consume_new_metric_and_maybe_arm_switch(layer_name, gemm, iteration, config, state) + + if iteration <= state.disable_until_iter: + if metric_logger is not None: + metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 0.0) + metric_logger.log_scalar( + layer_name, gemm, "disable_until_iter", iteration, float(state.disable_until_iter) + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}: {gemm} forced high precision at" + f" iter={iteration} (disable_until={state.disable_until_iter}).", + layer_name, + extra_cachable_args=(gemm, "high_precision"), + ) + return False, iteration + 1 + + if metric_logger is not None: + metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) + return True, iteration + 1 + + @api_method + def inspect_tensor_enabled( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + ): # pylint: 
disable=unused-argument + """Enable metric collection according to the standard freq/start/end controls.""" + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + return run_current, next_iter + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group: torch.distributed.ProcessGroup, # pylint: disable=unused-argument + tensor: Optional[torch.Tensor], + rowwise_quantized_tensor: Optional[torch.Tensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor] = None, + quantizer=None, # pylint: disable=unused-argument + tp_size: int = 1, # pylint: disable=unused-argument + ): + """Collect quantization quality metrics for autoswitch decisions.""" + if tensor_name == "weight" and tensor is None: + # Weight tensor unavailable in high precision indicates fp8 model params. + self._layer_has_fp8_model_params[layer_name] = True + + _ = config + gemms = self._TENSOR_TO_GEMMS.get(tensor_name, (None, None)) + + rowwise_gemm, columnwise_gemm = gemms + if rowwise_gemm is not None: + metrics = self._compute_metrics(tensor, rowwise_quantized_tensor) + if metrics is not None: + self._update_metric( + layer_name, rowwise_gemm, iteration, tensor_name, metrics[0], metrics[1] + ) + + if columnwise_gemm is not None: + metrics = self._compute_metrics(tensor, columnwise_quantized_tensor) + if metrics is not None: + self._update_metric( + layer_name, columnwise_gemm, iteration, tensor_name, metrics[0], metrics[1] + ) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index ed5fdd4660..7d52f3a875 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -320,14 +320,21 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) 
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + parent_has_quantized_usage = ( + self.parent_quantizer is not None + and (self.parent_quantizer.rowwise_usage or self.parent_quantizer.columnwise_usage) + ) + if ( + STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan] + and parent_has_quantized_usage + ): quantized_tensor = self.parent_quantizer(tensor) # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_QUANTIZE: + if self.rowwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.rowwise_usage: rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_QUANTIZE: + if self.columnwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.columnwise_usage: columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. @@ -562,25 +569,56 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): if not self.output_tensor: self._update_parent_quantizer_usage() - def wrap_quantized_tensor(self, tensor: QuantizedTensor): + def wrap_quantized_tensor( + self, tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None + ): """ Wraps the quantized tensor with the debug quantizer. It is used for weight tensors when fp8 model parameters are enabled. """ + if API_CALL_MODIFY in (self.rowwise_tensor_plan, self.columnwise_tensor_plan): + raise AssertionError( + "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot" + " be modified by modify_tensor()." 
+ ) + + dequantized_weight = None - assert ( + def _get_dequantized_weight(): + nonlocal dequantized_weight + if dequantized_weight is None: + output_dtype = dtype if dtype is not None else tensor.dtype + try: + dequantized_weight = tensor.dequantize(dtype=output_dtype) + except TypeError: + dequantized_weight = tensor.dequantize() + if dequantized_weight.dtype != output_dtype: + dequantized_weight = dequantized_weight.to(output_dtype) + return dequantized_weight + + if ( self.rowwise_tensor_plan == STANDARD_QUANTIZE and self.columnwise_tensor_plan == STANDARD_QUANTIZE - ), ( - "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot be" - " modified by any feature." - ) + ): + rowwise_tensor = tensor + columnwise_tensor = tensor + inspect_source = None + else: + rowwise_tensor = ( + tensor if self.rowwise_tensor_plan == STANDARD_QUANTIZE else _get_dequantized_weight() + ) + columnwise_tensor = ( + tensor + if self.columnwise_tensor_plan == STANDARD_QUANTIZE + else _get_dequantized_weight() + ) + inspect_source = _get_dequantized_weight() - self._call_inspect_tensor_api(None, tensor, tensor) + self._call_inspect_tensor_api(inspect_source, rowwise_tensor, columnwise_tensor) return DebugQuantizedTensor( - rowwise_gemm_tensor=tensor, - columnwise_gemm_tensor=tensor, + rowwise_gemm_tensor=rowwise_tensor, + columnwise_gemm_tensor=columnwise_tensor, quantizer=self, layer_name=self.layer_name, tensor_name=self.tensor_name, @@ -676,7 +714,8 @@ def size(self, *args): def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None): """Update usage of the tensor.""" - if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor: + same_storage = self.rowwise_gemm_tensor is self.columnwise_gemm_tensor + if not same_storage: # If the same object is used both for rowwise and columnwise gemms, # there is no benefit in erasing the usage of one of them. # And there are scenarios when not deleting the usage of one of them is needed. 
@@ -687,9 +726,27 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None self.columnwise_gemm_tensor = None if isinstance(self.rowwise_gemm_tensor, QuantizedTensor): - self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + if same_storage: + rowwise_rowwise_usage = rowwise_usage + rowwise_columnwise_usage = columnwise_usage + else: + # Keep rowwise storage focused on rowwise path. + rowwise_rowwise_usage = rowwise_usage + rowwise_columnwise_usage = False if columnwise_usage is not None else None + self.rowwise_gemm_tensor.update_usage( + rowwise_rowwise_usage, rowwise_columnwise_usage + ) if isinstance(self.columnwise_gemm_tensor, QuantizedTensor): - self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + if same_storage: + columnwise_rowwise_usage = rowwise_usage + columnwise_columnwise_usage = columnwise_usage + else: + # Keep columnwise storage focused on columnwise path. + columnwise_rowwise_usage = False if rowwise_usage is not None else None + columnwise_columnwise_usage = columnwise_usage + self.columnwise_gemm_tensor.update_usage( + columnwise_rowwise_usage, columnwise_columnwise_usage + ) if rowwise_usage and self.rowwise_gemm_tensor is None: raise RuntimeError( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 1b237ece29..3363578de6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1442,7 +1442,7 @@ def get_weight_workspace( ) if isinstance(quantizer, DebugQuantizer): - tensor = quantizer.wrap_quantized_tensor(tensor) + tensor = quantizer.wrap_quantized_tensor(tensor, dtype=workspace_dtype) return tensor From e0d16645a8623a1e370fb3c1ac138711321952a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 03:41:10 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- .../debug/features/autoswitch_gemm.py | 18 ++++++++++--- .../debug/pytorch/debug_quantization.py | 27 ++++++++++--------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 807947f627..f247807d1c 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -482,7 +482,11 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): # With fp8 model parameters enabled, fprop/dgrad can switch to high precision # only when dequantized fallback is explicitly enabled in config. - if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and not allow_fp8_model_params_fallback: + if ( + gemm in {"fprop", "dgrad"} + and fp8_model_params_layer + and not allow_fp8_model_params_fallback + ): state.disable_until_iter = -1 if metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) @@ -497,7 +501,11 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): ) return True, iteration + 1 - if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: + if ( + gemm in {"fprop", "dgrad"} + and fp8_model_params_layer + and allow_fp8_model_params_fallback + ): if metric_logger is not None: metric_logger.log_scalar( layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 @@ -515,7 +523,11 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): if metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 0.0) metric_logger.log_scalar( - layer_name, gemm, "disable_until_iter", iteration, float(state.disable_until_iter) + layer_name, + gemm, + "disable_until_iter", + iteration, + float(state.disable_until_iter), ) debug_api.log_message( 
f"Feature={self.__class__.__name__}: {gemm} forced high precision at" diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 7d52f3a875..3f499a02f9 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -320,9 +320,8 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - parent_has_quantized_usage = ( - self.parent_quantizer is not None - and (self.parent_quantizer.rowwise_usage or self.parent_quantizer.columnwise_usage) + parent_has_quantized_usage = self.parent_quantizer is not None and ( + self.parent_quantizer.rowwise_usage or self.parent_quantizer.columnwise_usage ) if ( STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan] @@ -332,9 +331,15 @@ def quantize( # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.rowwise_usage: + if ( + self.rowwise_tensor_plan == STANDARD_QUANTIZE + and self.parent_quantizer.rowwise_usage + ): rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.columnwise_usage: + if ( + self.columnwise_tensor_plan == STANDARD_QUANTIZE + and self.parent_quantizer.columnwise_usage + ): columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. 
@@ -569,9 +574,7 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): if not self.output_tensor: self._update_parent_quantizer_usage() - def wrap_quantized_tensor( - self, tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None - ): + def wrap_quantized_tensor(self, tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None): """ Wraps the quantized tensor with the debug quantizer. It is used for weight tensors when fp8 model parameters are enabled. @@ -605,7 +608,9 @@ def _get_dequantized_weight(): inspect_source = None else: rowwise_tensor = ( - tensor if self.rowwise_tensor_plan == STANDARD_QUANTIZE else _get_dequantized_weight() + tensor + if self.rowwise_tensor_plan == STANDARD_QUANTIZE + else _get_dequantized_weight() ) columnwise_tensor = ( tensor @@ -733,9 +738,7 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None # Keep rowwise storage focused on rowwise path. rowwise_rowwise_usage = rowwise_usage rowwise_columnwise_usage = False if columnwise_usage is not None else None - self.rowwise_gemm_tensor.update_usage( - rowwise_rowwise_usage, rowwise_columnwise_usage - ) + self.rowwise_gemm_tensor.update_usage(rowwise_rowwise_usage, rowwise_columnwise_usage) if isinstance(self.columnwise_gemm_tensor, QuantizedTensor): if same_storage: columnwise_rowwise_usage = rowwise_usage From 2907537d4bd069369434b8d9b270691fdb7651f5 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 15 Apr 2026 18:27:16 +0800 Subject: [PATCH 3/3] apply resolve_gemm_inputs_after_sampling before gemm --- .../debug/features/autoswitch_gemm.py | 36 ++-- .../debug/pytorch/gemm_runtime_hooks.py | 183 ++++++++++++++++++ transformer_engine/pytorch/module/linear.py | 35 ++++ 3 files changed, 238 insertions(+), 16 deletions(-) create mode 100644 transformer_engine/debug/pytorch/gemm_runtime_hooks.py diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 
f247807d1c..b4a05662dd 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -169,9 +169,10 @@ class AutoswitchGemm(TEConfigAPIMapper): mse_threshold: float, default = 1e-4 Trigger switch to high precision if quantization MSE exceeds this value. - The switch decision is same-iteration only: - metrics computed at iteration `n` are consumed in iteration `n`. - There is no cross-iteration hold window. + The switch decision is same-iteration: + metrics computed at iteration `n` are consumed in iteration `n` + after all GEMM input tensors are prepared. + The switch is applied for one iteration. allow_fp8_model_params_dequantized_weight: bool, default = False If True, allows `fprop`/`dgrad` to switch to high precision even when @@ -417,14 +418,14 @@ def _consume_new_metric_and_maybe_arm_switch( config: Dict, state: _GemmSwitchState, ) -> None: - """Consume current-iteration metrics and arm switch for this iteration only.""" + """Consume current-iteration metrics and arm switch for one iteration.""" metric = self._latest_metrics.get((layer_name, gemm)) if metric is None: return metric_iter = int(metric["iteration"]) if metric_iter != iteration: - # Autoswitch consumes metrics only in the same iteration they were produced. + # Autoswitch consumes metrics only in the iteration they were produced. return metric_snapshot = ( @@ -461,14 +462,21 @@ def _consume_new_metric_and_maybe_arm_switch( debug_api.log_message( f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" - f" iter={iteration}. Triggered by {metric['tensor_name']} at iter={metric_iter}:" + f" iter={iteration}. 
Triggered by {metric['tensor_name']} sampled at iter={metric_iter}:" f" {state.last_reason}", layer_name, extra_cachable_args=(gemm, "switch"), ) @api_method - def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): + def fp8_gemm_enabled( + self, + config, + layer_name: str, + gemm: str, + iteration: int, + final_decision: bool = False, + ): """Decide whether selected GEMM should run quantized (True) or high precision (False).""" state = self._get_or_create_state(layer_name, gemm) metric_logger = self._get_metrics_logger() @@ -488,7 +496,7 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): and not allow_fp8_model_params_fallback ): state.disable_until_iter = -1 - if metric_logger is not None: + if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) metric_logger.log_scalar( layer_name, gemm, "switch_blocked_fp8_model_params", iteration, 1.0 @@ -501,12 +509,8 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): ) return True, iteration + 1 - if ( - gemm in {"fprop", "dgrad"} - and fp8_model_params_layer - and allow_fp8_model_params_fallback - ): - if metric_logger is not None: + if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: + if final_decision and metric_logger is not None: metric_logger.log_scalar( layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 ) @@ -520,7 +524,7 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): self._consume_new_metric_and_maybe_arm_switch(layer_name, gemm, iteration, config, state) if iteration <= state.disable_until_iter: - if metric_logger is not None: + if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 0.0) metric_logger.log_scalar( layer_name, @@ -537,7 +541,7 @@ def fp8_gemm_enabled(self, config, 
layer_name: str, gemm: str, iteration: int): ) return False, iteration + 1 - if metric_logger is not None: + if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) return True, iteration + 1 diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py new file mode 100644 index 0000000000..3d4ecf9866 --- /dev/null +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -0,0 +1,183 @@ +"""Runtime GEMM hooks used by AutoswitchGemm.""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from transformer_engine.debug.pytorch.debug_state import TEDebugState +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.utils import cast_if_needed + +_AUTOSWITCH_FEATURE_NAME = "AutoswitchGemm" +_AUTOSWITCH_ENABLED_CACHE = {} + + +def _is_fp8_debug_quantizer(quantizer: Optional[Quantizer]) -> bool: + """Return True for DebugQuantizer objects wrapping an FP8/NVFP4 quantizer.""" + return ( + quantizer is not None + and quantizer.__class__.__name__ == "DebugQuantizer" + and getattr(quantizer, "parent_quantizer", None) is not None + ) + + +def _feature_block_enabled(feature_config: Any) -> bool: + """Return whether an Autoswitch feature block is enabled.""" + if isinstance(feature_config, dict): + return bool(feature_config.get("enabled", True)) + if isinstance(feature_config, bool): + return feature_config + return feature_config is not None + + +def _contains_enabled_autoswitch(config: Any, visited: Optional[set] = None) -> bool: + """Recursively check whether config contains enabled AutoswitchGemm feature.""" + if visited is None: + visited = set() + obj_id = id(config) + if obj_id in visited: + return False + visited.add(obj_id) + + if isinstance(config, dict): + for key, value in config.items(): + if key == 
_AUTOSWITCH_FEATURE_NAME and _feature_block_enabled(value): + return True + for value in config.values(): + if _contains_enabled_autoswitch(value, visited): + return True + return False + + if isinstance(config, (list, tuple, set)): + for item in config: + if _contains_enabled_autoswitch(item, visited): + return True + return False + + return False + + +def _autoswitch_feature_enabled() -> bool: + """Best-effort detection for whether AutoswitchGemm is enabled in debug config.""" + try: + import nvdlfw_inspect.api as debug_api + except ImportError: + return False + + manager = getattr(debug_api, "DEBUG_MANAGER", None) + if manager is None: + return False + + manager_id = id(manager) + cached = _AUTOSWITCH_ENABLED_CACHE.get(manager_id) + if cached is not None: + return cached + + candidate_configs = [] + for attr in ( + "config", + "_config", + "debug_config", + "_debug_config", + "user_config", + "_user_config", + "raw_config", + "_raw_config", + ): + value = getattr(manager, attr, None) + if value is not None: + candidate_configs.append(value) + + for attr_name, value in getattr(manager, "__dict__", {}).items(): + if "config" in attr_name.lower() and value is not None: + candidate_configs.append(value) + + if not candidate_configs: + # Keep previous behavior if manager internals cannot be introspected. 
+ _AUTOSWITCH_ENABLED_CACHE[manager_id] = True + return True + + enabled = any(_contains_enabled_autoswitch(config) for config in candidate_configs) + _AUTOSWITCH_ENABLED_CACHE[manager_id] = enabled + return enabled + + +def should_resolve_inputs_after_sampling( + lhs_quantizer: Optional[Quantizer], + rhs_quantizer: Optional[Quantizer], +) -> bool: + """Return True when runtime GEMM decision path should be applied.""" + if not (_is_fp8_debug_quantizer(lhs_quantizer) or _is_fp8_debug_quantizer(rhs_quantizer)): + return False + return _autoswitch_feature_enabled() + + +def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): + """Convert GEMM input to high precision tensor if needed.""" + if hasattr(tensor, "get_tensor") and hasattr(tensor, "rowwise_gemm_tensor"): + rowwise_tensor = _to_high_precision_gemm_input(tensor.get_tensor(False), dtype) + columnwise_src = tensor.get_tensor(True) + if columnwise_src is tensor.get_tensor(False): + columnwise_tensor = rowwise_tensor + else: + columnwise_tensor = _to_high_precision_gemm_input(columnwise_src, dtype) + tensor.rowwise_gemm_tensor = rowwise_tensor + tensor.columnwise_gemm_tensor = columnwise_tensor + return tensor + + if dtype is None: + dtype = getattr(tensor, "dtype", None) + if isinstance(tensor, QuantizedTensorStorage): + if dtype is None: + return tensor.dequantize() + try: + return tensor.dequantize(dtype=dtype) + except TypeError: + return cast_if_needed(tensor.dequantize(), dtype) + if dtype is None: + return tensor + return cast_if_needed(tensor, dtype) + + +def resolve_gemm_inputs_after_sampling( + gemm_name: str, + lhs, + rhs, + lhs_quantizer: Optional[Quantizer], + rhs_quantizer: Optional[Quantizer], + target_dtype: torch.dtype, +): + """ + Make post-sampling GEMM precision decision and enforce OR logic across inputs. + + If any sampled input for this GEMM triggers high precision, both GEMM inputs are + converted to high precision tensors before kernel launch. 
+ """ + layer_name = ( + getattr(lhs_quantizer, "layer_name", None) or getattr(rhs_quantizer, "layer_name", None) + ) + if layer_name is None: + return lhs, rhs + + try: + import nvdlfw_inspect.api as debug_api + except ImportError: + return lhs, rhs + + iteration = TEDebugState.get_iteration() + enabled_ret = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=layer_name, + gemm=gemm_name, + iteration=iteration, + final_decision=True, + ) + quantized_enabled = enabled_ret[0] if isinstance(enabled_ret, tuple) else enabled_ret + if quantized_enabled: + return lhs, rhs + + return ( + _to_high_precision_gemm_input(lhs, target_dtype), + _to_high_precision_gemm_input(rhs, target_dtype), + ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8510f6cf8f..1dadb57f39 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -73,6 +73,10 @@ mark_not_offload, mark_activation_offload, ) +from ...debug.pytorch.gemm_runtime_hooks import ( + resolve_gemm_inputs_after_sampling, + should_resolve_inputs_after_sampling, +) from ...debug.pytorch.debug_state import TEDebugState __all__ = ["Linear"] @@ -335,6 +339,15 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ + if debug and should_resolve_inputs_after_sampling(weight_quantizer, input_quantizer): + weightmat, inputmat_total = resolve_gemm_inputs_after_sampling( + "fprop", + weightmat, + inputmat_total, + weight_quantizer, + input_quantizer, + activation_dtype, + ) nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weightmat, @@ -760,6 +773,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = weight if isinstance(weight_for_dgrad, QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + if ctx.debug and 
should_resolve_inputs_after_sampling( + ctx.weight_quantizer, ctx.grad_output_quantizer + ): + weight_for_dgrad, grad_output = resolve_gemm_inputs_after_sampling( + "dgrad", + weight_for_dgrad, + grad_output, + ctx.weight_quantizer, + ctx.grad_output_quantizer, + ctx.activation_dtype, + ) gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, @@ -920,6 +944,17 @@ def wgrad_gemm( some advanced communication/compute overlapping. """ + if ctx.debug and should_resolve_inputs_after_sampling( + ctx.input_quantizer, ctx.grad_output_quantizer + ): + x, dy = resolve_gemm_inputs_after_sampling( + "wgrad", + x, + dy, + ctx.input_quantizer, + ctx.grad_output_quantizer, + ctx.activation_dtype, + ) nvtx_range_push(f"{nvtx_label}.wgrad_gemm") dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")