Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
Expand Down
41 changes: 31 additions & 10 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
Expand Down Expand Up @@ -123,21 +124,25 @@ def __fx_repr__(self):
_Q = get_opaque_type_name(ToyQuantizer)

def _make_qfactory(tag: str):
"""Return a qfactory that produces ToyQuantizer instances tagged with *tag*."""
"""Return a qfactory that produces ToyQuantizer instances tagged with *tag*.

The factory dispatches on ``QuantizerRole.tensor_type``; the roles are
supplied by :meth:`ToyLinear.get_quantizer_roles`.
"""

quantizers = {
role: ToyQuantizer(tag=f"{tag}:{role}")
for role in (
"linear_input",
"linear_weight",
"linear_output",
"linear_grad_output",
"linear_grad_input",
tensor_type: ToyQuantizer(tag=f"{tag}:{tensor_type}")
for tensor_type in (
"input",
"weight",
"output",
"grad_output",
"grad_input",
)
}

def qfactory(role: str):
return quantizers[role]
def qfactory(role: QuantizerRole):
return quantizers[role.tensor_type]

return qfactory

Expand All @@ -163,6 +168,22 @@ def __init__(
)
torch.nn.init.normal_(self.weight)

def get_quantizer_roles(self, *, fwd: bool, num_quantizers: int):
    """Return the explicit ``QuantizerRole`` list for this module's quantizers.

    Forward yields the roles for input, weight and output; backward yields
    the roles for grad_output and grad_input. ``num_quantizers`` is accepted
    for interface compatibility and is not consulted here.
    """
    # Explicit roles stop CustomRecipeState from emitting a warning (a
    # warning would graph-break under fullgraph=True) and let the qfactory
    # pick a quantizer per tensor slot. The ordering has to follow the
    # module's quantizer array (FP8FwdTensorIdx / FP8BwdTensorIdx).
    if fwd:
        tensor_types = ("input", "weight", "output")
    else:
        tensor_types = ("grad_output", "grad_input")
    return [
        QuantizerRole(module_type="linear", tensor_type=tensor_type)
        for tensor_type in tensor_types
    ]

def _get_weight_tensors(self):
return [self.weight]

Expand Down
12 changes: 11 additions & 1 deletion transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool):
return _ub_communicators[key]


@torch.compiler.assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
    """Report whether the named UB communicator uses an FP8 buffer.

    Looks the communicator up through ``get_ub(name, use_fp8)`` and returns
    its ``is_fp8_ubuf()`` result. The function is marked with
    ``torch.compiler.assume_constant_result``, so compiled code treats the
    answer as a compile-time constant.
    """
    communicator = get_ub(name, use_fp8)
    return communicator.is_fp8_ubuf()
Comment on lines +557 to +560

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 assume_constant_result can become stale after destroy_ub() + re-init

@torch.compiler.assume_constant_result caches the return value per (name, use_fp8) argument pair for the lifetime of a compiled region. If destroy_ub() is called and UB communicators are re-initialized with different FP8 settings (e.g. in a test harness that re-creates the communicators), the cached is_fp8_ubuf() result would be silently stale until the next recompile. In production training this should not happen — UB is typically initialized once — but test suites that tear down and rebuild UB communicators between cases could observe incorrect fp8_output/fp8_grad flags without triggering a recompile.



def destroy_ub():
"""Destroy all allocated userbuffer communicators."""
global _ub_communicators, _ub_with_cublasmp, _ub_initialized
Expand All @@ -562,6 +568,9 @@ def destroy_ub():
_ub_initialized = False
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
# Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result;
# reset so re-init with different settings doesn't read stale constants.
torch.compiler.reset()


def fill_userbuffers_buffer_for_all_gather(
Expand Down Expand Up @@ -1049,7 +1058,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
return
if recipe.custom() and isinstance(recipe_state, CustomRecipeState):
return
if recipe_state.recipe is recipe:
return

# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
Expand Down
15 changes: 7 additions & 8 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -1051,8 +1052,10 @@ def wgrad_gemm(
if ctx.ln_out_needs_gather:
# Gathered input is internal
clear_tensor_data(ln_out_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
if ctx.sequence_parallel and (
ctx.parallel_mode == "row" or (ctx.parallel_mode == "column" and ctx.fp8)
):
# Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal
clear_tensor_data(grad_output)

# Update grad input if overlapping reduce-scatter with wgrad GEMM
Expand Down Expand Up @@ -1668,14 +1671,10 @@ def forward(
is_first_microbatch = False

if self.ub_overlap_rs_fprop:
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()):
fp8_grad = True

inp = self.prepare_forward(
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
fill_userbuffers_buffer_for_all_gather,
_ub_communicators,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -2292,7 +2293,7 @@ def forward(

fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
if get_ub_is_fp8("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True

inp = self.prepare_forward(inp, num_gemms=2)
Expand Down
47 changes: 26 additions & 21 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -117,6 +118,8 @@ class LinearFwdArgs:
fp8_output: bool
save_original_input: bool
backward_override: Optional[str]
dgrad_use_split_accumulator: bool
wgrad_use_split_accumulator: bool
custom: bool
debug: bool

Expand Down Expand Up @@ -183,7 +186,8 @@ class LinearBwdArgs:
# --- Numerical / dtype config ---
activation_dtype: Optional[torch.dtype] = None
fp8: bool = False
fp8_recipe: Optional[Recipe] = None
dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD
wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD
backward_override: Optional[str] = None
is_weight_param_quantized: bool = False
custom: bool = False
Expand Down Expand Up @@ -656,7 +660,8 @@ def _linear_setup_ctx(
# Numerical / dtype config
bwd_args.activation_dtype = fwd_args.activation_dtype
bwd_args.fp8 = fp8
bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
bwd_args.dgrad_use_split_accumulator = fwd_args.dgrad_use_split_accumulator
bwd_args.wgrad_use_split_accumulator = fwd_args.wgrad_use_split_accumulator
bwd_args.backward_override = backward_override
bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage)
bwd_args.custom = fwd_args.custom
Expand Down Expand Up @@ -958,11 +963,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], ..
weight_fp8.update_usage(columnwise_usage=True)

# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if bwd_args.fp8:
recipe = bwd_args.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
use_split_accumulator = bwd_args.dgrad_use_split_accumulator

# Update grad input quantizer
if grad_input_quantizer is not None:
Expand Down Expand Up @@ -1133,11 +1134,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], ..
grad_output = grad_output_quantizer(grad_output)

# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
if bwd_args.fp8:
recipe = bwd_args.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
use_split_accumulator = bwd_args.wgrad_use_split_accumulator

# Figure out whether to output wgrad GEMM directly into main grad
if bwd_args.is_first_microbatch is not None:
Expand Down Expand Up @@ -1228,8 +1225,11 @@ def wgrad_gemm(
elif bwd_args.backward_input_needs_gather:
# Gathered input tensor is internal
clear_tensor_data(inputmat_total)
if bwd_args.parallel_mode == "row" and bwd_args.sequence_parallel:
# Gathered grad output tensor is internal
if bwd_args.sequence_parallel and (
bwd_args.parallel_mode == "row"
or (bwd_args.parallel_mode == "column" and bwd_args.fp8)
):
# Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal
clear_tensor_data(grad_output)

# Update grad input if overlapping reduce-scatter with wgrad GEMM
Expand Down Expand Up @@ -1816,14 +1816,10 @@ def forward(
is_first_microbatch = False

if self.ub_overlap_rs_fprop:
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()):
fp8_grad = True

inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
Expand Down Expand Up @@ -1861,8 +1857,15 @@ def forward(
self._fp8_workspaces.get(cache_name) if cache_name is not None else None
)

dgrad_use_split_accumulator = _2X_ACC_DGRAD
wgrad_use_split_accumulator = _2X_ACC_WGRAD
if self.fp8:
backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
_recipe = FP8GlobalStateManager.get_fp8_recipe()
backward_override = _recipe.backward_override
if hasattr(_recipe, "fp8_gemm_dgrad"):
dgrad_use_split_accumulator = _recipe.fp8_gemm_dgrad.use_split_accumulator
if hasattr(_recipe, "fp8_gemm_wgrad"):
wgrad_use_split_accumulator = _recipe.fp8_gemm_wgrad.use_split_accumulator
else:
backward_override = None
custom = is_custom(input_quantizer) or is_custom(weight_quantizer)
Expand Down Expand Up @@ -1917,6 +1920,8 @@ def forward(
fp8_output=fp8_output,
save_original_input=self.save_original_input,
backward_override=backward_override,
dgrad_use_split_accumulator=dgrad_use_split_accumulator,
wgrad_use_split_accumulator=wgrad_use_split_accumulator,
custom=custom,
debug=debug,
# weight-workspace caching
Expand Down
Loading