Skip to content

Avoid unpickling the extra state when not needed#3123

Open
ptrendx wants to merge 3 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle
Open

Avoid unpickling the extra state when not needed#3123
ptrendx wants to merge 3 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle

Conversation

@ptrendx

@ptrendx ptrendx commented Jun 12, 2026

Copy link
Copy Markdown
Member

Description

Avoids unpickling of the extra state if the recipe is stateless. Adds a guard prompting user to explicitly allow loading of the checkpoint when the unpickling is necessary.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Avoids unpickling of the stateless recipe extra state
  • Adds a guard and environment variable for the delayed scaling recipes

ptrendx added 2 commits June 12, 2026 05:24
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces a security-aware checkpoint loading mechanism for FP8 extra state by adding a CheckpointExtraStatePolicy enum to each recipe class and a new _extra_state.py module that classifies pickled extra-state payloads using pickletools.genops — without executing them — before deciding whether to deserialize. Stateless recipes now write an empty tensor on save and skip deserialization on load; delayed-scaling recipes require NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1.

  • CheckpointExtraStatePolicy (STATELESS / STATEFUL / DYNAMIC) is declared as a ClassVar on every Recipe subclass; get_extra_state returns an empty tensor for STATELESS, while STATEFUL and DYNAMIC pickles go through the new classifier before pickle.loads.
  • _classify_extra_state_pickle_impl inspects pickle opcodes to detect the recipe class, match it against a module-level _RECIPE_POLICIES dict, and flag delayed-state key names (scale_fwd etc.), returning IGNORE or UNSAFE_LOAD accordingly.
  • Test files are updated to set NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 around load_state_dict calls for checkpoints that use delayed-scaling recipes.

Confidence Score: 3/5

The core classification logic has a gap: CustomRecipe checkpoints whose state uses keys outside the four hard-coded delayed-scaling names are silently and irrecoverably dropped on load, with no env-var escape hatch, on top of the pre-existing save/load asymmetry for non-delayed CustomRecipe.

Two independent correctness problems converge on the CustomRecipe (DYNAMIC) path. First, get_extra_state writes a non-empty pickle for DYNAMIC recipes whether or not there is any delayed state, while set_extra_state silently discards that pickle when no delayed-state keys are found. Second, the IGNORE branch in _classify_extra_state_pickle_impl is reached before the env-var check, so even a user who sets NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 cannot force loading of a DYNAMIC checkpoint that lacks the four standard delayed-state key names.

transformer_engine/pytorch/_extra_state.py (classifier IGNORE/UNSAFE_LOAD branching for DYNAMIC policy) and transformer_engine/pytorch/module/base.py + transformer_engine/pytorch/ops/op.py (get_extra_state write path for DYNAMIC recipes).

Important Files Changed

Filename Overview
transformer_engine/pytorch/_extra_state.py New module implementing the pickle classifier via pickletools opcode inspection; DYNAMIC without delayed-state keys resolves to IGNORE even when NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 is set, silently dropping custom recipe state with no opt-in path.
transformer_engine/pytorch/module/base.py get_extra_state skips serialization only for STATELESS; DYNAMIC (CustomRecipe) still writes recipe + extra_fp8_variables, but set_extra_state silently ignores that payload, creating a save/load asymmetry.
transformer_engine/pytorch/ops/op.py Same DYNAMIC save/load asymmetry as base.py: get_extra_state serializes per-mode state for DYNAMIC recipes, but set_extra_state ignores it when no delayed-state keys are present.
transformer_engine/common/recipe/init.py Adds CheckpointExtraStatePolicy enum and assigns STATELESS/STATEFUL/DYNAMIC ClassVar to each Recipe subclass; clean and consistent.
tests/pytorch/test_recipe.py New unit tests cover classifier correctness for all policy variants, including legacy/malformed payloads and the env-var opt-in path.
tests/pytorch/test_checkpoint.py Correctly wraps load_state_dict with NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 only when quantization=="fp8", restoring the previous env-var value in a finally block.
tests/pytorch/test_numerics.py Guards the env var around load_state_dict using recipe.delayed(), correctly limiting opt-in to delayed-scaling recipes.
tests/pytorch/test_fusible_ops.py Opts in for both fp8 and fp8_delayed_scaling quantization variants; env-var save/restore pattern is correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_extra_state()"] --> B{fp8_checkpoint?}
    B -- No --> C["return empty tensor"]
    B -- Yes --> D{checkpoint_extra_state_policy?}
    D -- STATELESS --> C
    D -- "STATEFUL / DYNAMIC" --> E["serialize: recipe + extra_fp8_variables\n+ delayed tensors if present"]
    E --> F["pickle.dumps → byte tensor"]

    G["set_extra_state(state)"] --> H{state empty?}
    H -- Yes --> I["return early"]
    H -- No --> J["should_load_extra_state_pickle()"]
    J --> K["_classify_extra_state_pickle_impl()"]
    K --> L{has_recipe_key?}
    L -- No --> M["UNSAFE_LOAD\n(legacy TE 1.x)"]
    L -- Yes --> N{STATEFUL in policies?}
    N -- Yes --> M
    N -- No --> O{has_delayed_state_keys?}
    O -- Yes --> M
    O -- No --> P{policies empty?}
    P -- Yes --> M
    P -- No --> Q["IGNORE\n(STATELESS or DYNAMIC\nwithout delayed state)"]
    M --> R{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
    R -- No --> S["raise RuntimeError"]
    R -- Yes --> T["pickle.loads → restore\nrecipe + tensors"]
    Q --> I
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["get_extra_state()"] --> B{fp8_checkpoint?}
    B -- No --> C["return empty tensor"]
    B -- Yes --> D{checkpoint_extra_state_policy?}
    D -- STATELESS --> C
    D -- "STATEFUL / DYNAMIC" --> E["serialize: recipe + extra_fp8_variables\n+ delayed tensors if present"]
    E --> F["pickle.dumps → byte tensor"]

    G["set_extra_state(state)"] --> H{state empty?}
    H -- Yes --> I["return early"]
    H -- No --> J["should_load_extra_state_pickle()"]
    J --> K["_classify_extra_state_pickle_impl()"]
    K --> L{has_recipe_key?}
    L -- No --> M["UNSAFE_LOAD\n(legacy TE 1.x)"]
    L -- Yes --> N{STATEFUL in policies?}
    N -- Yes --> M
    N -- No --> O{has_delayed_state_keys?}
    O -- Yes --> M
    O -- No --> P{policies empty?}
    P -- Yes --> M
    P -- No --> Q["IGNORE\n(STATELESS or DYNAMIC\nwithout delayed state)"]
    M --> R{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
    R -- No --> S["raise RuntimeError"]
    R -- Yes --> T["pickle.loads → restore\nrecipe + tensors"]
    Q --> I
Loading

Reviews (2): Last reviewed commit: "Merge branch 'main' into pr_avoid_unpick..." | Re-trigger Greptile

Comment on lines +37 to +41
_RECIPE_POLICIES = {
(_RECIPE_MODULE, cls.__name__): cls.checkpoint_extra_state_policy
for cls in _recipe_subclasses(Recipe)
if cls.checkpoint_extra_state_policy is not None
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _RECIPE_POLICIES misses user-defined Recipe subclasses

_RECIPE_POLICIES is computed once at import time and only contains subclasses visible inside transformer_engine.common.recipe. Any Recipe subclass defined in user code (or in a different module) will not appear in this dict, so _classify_extra_state_pickle_impl will find an empty policies set and return UNSAFE_LOAD, forcing the user to set NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 even for a genuinely stateless custom recipe. The test test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes only asserts coverage for first-party recipes and will not catch this gap for downstream users.

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.

HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)


class CheckpointExtraStatePolicy(Enum):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's uncomfortable how PyTorch-specific logic is making its way into te.common. Perhaps we should try to phrase this more generically so it's not about torch.nn.Module.set_extra_state, but about recipe statefulness. However, this won't work if we implement future stateful recipes and we want to add an enum value for non-pickle checkpoint formats.

Since the bug comes from a mistake in te.pytorch, it might be better for the hacky WAR to also live in te.pytorch. I see that _extra_state.py does some type inspection to create a map between recipes and policies:
https://github.com/ptrendx/TransformerEngine/blob/69a5ab8fd4bfeb188b40d8886faa155c9014aeaf/transformer_engine/pytorch/_extra_state.py#L37-L41
How about we just define the recipe-policy map directly? Each time we add a new stateful recipe, we can update the map and add any additional checkpointing logic there instead of touching te.common.


_RECIPE_MODULE = "transformer_engine.common.recipe"
_RECIPE_KEY = "recipe"
_DELAYED_STATE_KEYS = {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: All of this is really a backward-compatibility hack for FP8 delayed scaling, and if we implement other delayed scaling recipes in the future we're not going to use pickles in checkpoints.

Suggested change
_DELAYED_STATE_KEYS = {
_FLOAT8_DELAYED_SCALING_STATE_KEYS = {

"""

STATELESS = "stateless"
STATEFUL = "stateful"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.

Suggested change
STATEFUL = "stateful"
STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"

Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +139 to +142
if not policies:
return _PickledExtraStateAction.UNSAFE_LOAD

return _PickledExtraStateAction.IGNORE

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 DYNAMIC without delayed state is silently ignored even with the opt-in env var

The final return _PickledExtraStateAction.IGNORE is reached when policies contains only DYNAMIC (i.e. CustomRecipe) and has_delayed_state_keys is False. should_load_extra_state_pickle short-circuits on IGNORE before ever consulting unsafe_pickle_extra_state_enabled(), so setting NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 has no effect for this case.

A CustomRecipe user whose checkpoint contains state stored under non-standard key names (anything outside _DELAYED_STATE_KEYS) will have that state silently dropped on load with no recoverable opt-in path, even if they know the checkpoint is from a trusted source and explicitly set the env var.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants