From c043ae6546062e8c89fe93357032a29f67c75a9a Mon Sep 17 00:00:00 2001 From: lxd-cumt <1141051934@qq.com> Date: Wed, 10 Jun 2026 17:20:09 +0800 Subject: [PATCH 1/4] add entrypoint for flagos multi-backend plugin system Signed-off-by: Xianduo Li --- transformer_engine/common/__init__.py | 10 ++++++++++ .../dot_product_attention/dot_product_attention.py | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 40933f17a9..90831f4991 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -191,6 +191,16 @@ def load_framework_extension(framework: str) -> None: sys.modules[module_name] = solib spec.loader.exec_module(solib) + # Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over + # transformer_engine_torch and register original pybind as _nv for CUDA backend. + if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": + sys.modules[module_name + "_nv"] = solib + try: + from plugin import load_plugins + load_plugins() + except ImportError as e: + print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}") + def sanity_checks_for_pypi_installation() -> None: """Ensure that package is installed correctly if using PyPI.""" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..a9bf80ed94 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -148,6 +148,14 @@ _dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1" +# Plugin system: override FlashAttention and get_attention_backend if enabled +if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": + _FlashAttentionNative = FlashAttention + FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative) + dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend + dpa_utils.get_attention_backend = tex.get_attention_backend + + __all__ = ["DotProductAttention"] From c590d4d1dbceaa3c56a94baa9c5932e41d69461a Mon Sep 17 00:00:00 2001 From: lxd-cumt <1141051934@qq.com> Date: Mon, 15 Jun 2026 17:20:21 +0800 Subject: [PATCH 2/4] support flagos plugin system Signed-off-by: Xianduo Li --- transformer_engine/common/__init__.py | 12 +++++++++--- .../dot_product_attention/dot_product_attention.py | 6 ++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 90831f4991..0bdc5b5280 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -16,6 +16,7 @@ import sys import sysconfig from typing import Optional, Tuple +import warnings @functools.lru_cache(maxsize=None) @@ -193,13 +194,18 @@ def load_framework_extension(framework: str) -> None: # Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over # transformer_engine_torch and register original pybind as _nv for CUDA backend. - if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": + # Only applies to the PyTorch extension — JAX has no plugin stub. + if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1" and framework == "torch": sys.modules[module_name + "_nv"] = solib try: - from plugin import load_plugins + from transformer_engine_plugin_fl import load_plugins load_plugins() except ImportError as e: - print(f"[TE] NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}") + warnings.warn( + f"NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}", + ImportWarning, + stacklevel=2, + ) def sanity_checks_for_pypi_installation() -> None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index a9bf80ed94..cb2e45d3c7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -152,8 +152,10 @@ if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": _FlashAttentionNative = FlashAttention FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative) - dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend - dpa_utils.get_attention_backend = tex.get_attention_backend + _plugin_get_attention_backend = getattr(tex, "get_attention_backend", None) + if _plugin_get_attention_backend is not None: + dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend + dpa_utils.get_attention_backend = _plugin_get_attention_backend __all__ = ["DotProductAttention"] From 9a869ecfadfb35ccb0b5f8662b04f82da32403d6 Mon Sep 17 00:00:00 2001 From: Xianduo Li Date: Mon, 15 Jun 2026 17:37:55 +0800 Subject: [PATCH 3/4] polish Signed-off-by: Xianduo Li --- transformer_engine/common/__init__.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 0bdc5b5280..9c92cf2bc1 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -196,14 +196,16 @@ def load_framework_extension(framework: str) -> None: # transformer_engine_torch and register original pybind as _nv for CUDA backend. # Only applies to the PyTorch extension — JAX has no plugin stub. if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1" and framework == "torch": - sys.modules[module_name + "_nv"] = solib try: from transformer_engine_plugin_fl import load_plugins + sys.modules[module_name + "_nv"] = solib load_plugins() - except ImportError as e: + except Exception as e: + # Rollback _nv registration if plugin failed to fully initialize + sys.modules.pop(module_name + "_nv", None) warnings.warn( - f"NVTE_ENABLE_PLUGIN=1 but plugin import failed: {e}", - ImportWarning, + f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}", + RuntimeWarning, stacklevel=2, ) From ad0c9bd517e0dc53d8c02c3c749342fa2f5116ba Mon Sep 17 00:00:00 2001 From: Xianduo Li Date: Mon, 15 Jun 2026 18:02:09 +0800 Subject: [PATCH 4/4] fix: complete sys.modules rollback on plugin load failure Signed-off-by: Xianduo Li --- transformer_engine/common/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 9c92cf2bc1..bbe13790a6 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -196,13 +196,19 @@ def load_framework_extension(framework: str) -> None: # transformer_engine_torch and register original pybind as _nv for CUDA backend. # Only applies to the PyTorch extension — JAX has no plugin stub. if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1" and framework == "torch": + _original_module = sys.modules.get(module_name) try: from transformer_engine_plugin_fl import load_plugins + sys.modules[module_name + "_nv"] = solib load_plugins() except Exception as e: - # Rollback _nv registration if plugin failed to fully initialize + # Rollback to pre-plugin state if plugin failed to fully initialize sys.modules.pop(module_name + "_nv", None) + if _original_module is not None: + sys.modules[module_name] = _original_module + else: + sys.modules.pop(module_name, None) warnings.warn( f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}", RuntimeWarning,