Avoid unpickling the extra state when not needed #3123
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile Summary
This PR introduces a security-aware checkpoint loading mechanism for FP8 extra state by adding a
Confidence Score: 3/5
The core classification logic has a gap: CustomRecipe checkpoints whose state uses keys outside the four hard-coded delayed-scaling names are silently and irrecoverably dropped on load, with no env-var escape hatch, on top of the pre-existing save/load asymmetry for non-delayed CustomRecipe.
Two independent correctness problems converge on the CustomRecipe (DYNAMIC) path. First, get_extra_state writes a non-empty pickle for DYNAMIC recipes whether or not there is any delayed state, while set_extra_state silently discards that pickle when no delayed-state keys are found. Second, the IGNORE branch in _classify_extra_state_pickle_impl is reached before the env-var check, so even a user who sets NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 cannot force loading of a DYNAMIC checkpoint that lacks the four standard delayed-state key names.
transformer_engine/pytorch/_extra_state.py (classifier IGNORE/UNSAFE_LOAD branching for DYNAMIC policy) and transformer_engine/pytorch/module/base.py + transformer_engine/pytorch/ops/op.py (get_extra_state write path for DYNAMIC recipes).
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["get_extra_state()"] --> B{fp8_checkpoint?}
B -- No --> C["return empty tensor"]
B -- Yes --> D{checkpoint_extra_state_policy?}
D -- STATELESS --> C
D -- "STATEFUL / DYNAMIC" --> E["serialize: recipe + extra_fp8_variables\n+ delayed tensors if present"]
E --> F["pickle.dumps → byte tensor"]
G["set_extra_state(state)"] --> H{state empty?}
H -- Yes --> I["return early"]
H -- No --> J["should_load_extra_state_pickle()"]
J --> K["_classify_extra_state_pickle_impl()"]
K --> L{has_recipe_key?}
L -- No --> M["UNSAFE_LOAD\n(legacy TE 1.x)"]
L -- Yes --> N{STATEFUL in policies?}
N -- Yes --> M
N -- No --> O{has_delayed_state_keys?}
O -- Yes --> M
O -- No --> P{policies empty?}
P -- Yes --> M
P -- No --> Q["IGNORE\n(STATELESS or DYNAMIC\nwithout delayed state)"]
M --> R{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
R -- No --> S["raise RuntimeError"]
R -- Yes --> T["pickle.loads → restore\nrecipe + tensors"]
Q --> I
Reviews (2): Last reviewed commit: "Merge branch 'main' into pr_avoid_unpick..."
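The load-side branch of the flowchart can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions, not the PR's code: `Policy`, `Action`, and `classify` are hypothetical stand-ins, and the inputs mirror the decision nodes in the diagram rather than the real signature of `_classify_extra_state_pickle_impl`, which is not shown here.

```python
from enum import Enum


class Policy(Enum):
    # Hypothetical mirror of the policy values named in the flowchart.
    STATELESS = "stateless"
    STATEFUL = "stateful"
    DYNAMIC = "dynamic"


class Action(Enum):
    # Hypothetical mirror of the two classifier outcomes.
    IGNORE = "ignore"
    UNSAFE_LOAD = "unsafe_load"


def classify(has_recipe_key, policies, has_delayed_state_keys):
    """Follow the flowchart's decision nodes in order."""
    if not has_recipe_key:
        return Action.UNSAFE_LOAD  # legacy checkpoint without a recipe entry
    if Policy.STATEFUL in policies:
        return Action.UNSAFE_LOAD  # stateful recipe: the pickle carries real state
    if has_delayed_state_keys:
        return Action.UNSAFE_LOAD  # delayed-scaling tensors are present
    if not policies:
        return Action.UNSAFE_LOAD  # unknown recipe: take the guarded path
    return Action.IGNORE  # STATELESS, or DYNAMIC without delayed state
```

Only the last branch skips unpickling entirely; every other outcome still goes through the environment-variable gate shown in the diagram.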
_RECIPE_POLICIES = {
    (_RECIPE_MODULE, cls.__name__): cls.checkpoint_extra_state_policy
    for cls in _recipe_subclasses(Recipe)
    if cls.checkpoint_extra_state_policy is not None
}
_RECIPE_POLICIES misses user-defined Recipe subclasses
_RECIPE_POLICIES is computed once at import time and only contains subclasses visible inside transformer_engine.common.recipe. Any Recipe subclass defined in user code (or in a different module) will not appear in this dict, so _classify_extra_state_pickle_impl will find an empty policies set and return UNSAFE_LOAD, forcing the user to set NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 even for a genuinely stateless custom recipe. The test test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes only asserts coverage for first-party recipes and will not catch this gap for downstream users.
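The import-time gap described above can be reproduced with a small, self-contained sketch. All names here (`Recipe`, `BuiltinStateless`, `UserStateless`, `REGISTRY`, `lookup`) are hypothetical stand-ins, and the registry keys on each class's own `__module__` as a simplification of the `(_RECIPE_MODULE, cls.__name__)` key in the snippet.

```python
class Recipe:
    policy = None  # hypothetical stand-in for checkpoint_extra_state_policy


class BuiltinStateless(Recipe):
    policy = "stateless"


def all_subclasses(cls):
    """Recursively collect subclasses that exist at call time."""
    found = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(all_subclasses(sub))
    return found


# Snapshot taken once, as a module-level dict would be at import time.
REGISTRY = {
    (c.__module__, c.__name__): c.policy
    for c in all_subclasses(Recipe)
    if c.policy is not None
}


class UserStateless(Recipe):
    # Defined after the snapshot, like a recipe declared in user code.
    policy = "stateless"


def lookup(cls):
    return REGISTRY.get((cls.__module__, cls.__name__))
```

In this sketch `lookup(UserStateless)` returns `None` even though the class declares a stateless policy, which corresponds to the empty policies set that sends the classifier down the `UNSAFE_LOAD` path.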
timmoon10 left a comment
Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)


class CheckpointExtraStatePolicy(Enum):
It's uncomfortable how PyTorch-specific logic is making its way into te.common. Perhaps we should try to phrase this more generically so it's not about torch.nn.Module.set_extra_state, but about recipe statefulness. However, this won't work if we implement future stateful recipes and we want to add an enum value for non-pickle checkpoint formats.
Since the bug comes from a mistake in te.pytorch, it might be better for the hacky WAR to also live in te.pytorch. I see that _extra_state.py does some type inspection to create a map between recipes and policies:
https://github.com/ptrendx/TransformerEngine/blob/69a5ab8fd4bfeb188b40d8886faa155c9014aeaf/transformer_engine/pytorch/_extra_state.py#L37-L41
How about we just define the recipe-policy map directly? Each time we add a new stateful recipe, we can update the map and add any additional checkpointing logic there instead of touching te.common.
_RECIPE_MODULE = "transformer_engine.common.recipe"
_RECIPE_KEY = "recipe"
_DELAYED_STATE_KEYS = {
Nit: All of this is really a backward-compatibility hack for FP8 delayed scaling, and if we implement other delayed scaling recipes in the future we're not going to use pickles in checkpoints.
- _DELAYED_STATE_KEYS = {
+ _FLOAT8_DELAYED_SCALING_STATE_KEYS = {
    """

    STATELESS = "stateless"
    STATEFUL = "stateful"
We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.
- STATEFUL = "stateful"
+ STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"
Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
if not policies:
    return _PickledExtraStateAction.UNSAFE_LOAD

return _PickledExtraStateAction.IGNORE
DYNAMIC without delayed state is silently ignored even with the opt-in env var
The final return _PickledExtraStateAction.IGNORE is reached when policies contains only DYNAMIC (i.e. CustomRecipe) and has_delayed_state_keys is False. should_load_extra_state_pickle short-circuits on IGNORE before ever consulting unsafe_pickle_extra_state_enabled(), so setting NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 has no effect for this case.
A CustomRecipe user whose checkpoint contains state stored under non-standard key names (anything outside _DELAYED_STATE_KEYS) will have that state silently dropped on load with no recoverable opt-in path, even if they know the checkpoint is from a trusted source and explicitly set the env var.
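The ordering problem can be shown in isolation. This is a hedged sketch, not the PR's implementation: `Action` and `should_load` are hypothetical stand-ins for `_PickledExtraStateAction` and `should_load_extra_state_pickle`, and the environment is passed in as a plain dict so the behaviour is easy to inspect.

```python
from enum import Enum

ENV_VAR = "NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE"


class Action(Enum):
    # Hypothetical mirror of the two classifier outcomes.
    IGNORE = "ignore"
    UNSAFE_LOAD = "unsafe_load"


def should_load(action, environ):
    """Decide whether to unpickle, checking IGNORE before the opt-in."""
    if action is Action.IGNORE:
        # Short-circuit: the opt-in variable is never consulted on this path.
        return False
    if environ.get(ENV_VAR) == "1":
        return True
    raise RuntimeError(f"set {ENV_VAR}=1 to allow unpickling extra state")
```

Because the `IGNORE` check comes first, `should_load(Action.IGNORE, {ENV_VAR: "1"})` returns `False`: the opt-in only affects checkpoints already classified as `UNSAFE_LOAD`.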
Description
Avoids unpickling the extra state when the recipe is stateless. Adds a guard that requires the user to explicitly allow checkpoint loading when unpickling is necessary.
Type of change
Changes
Please list the changes introduced in this PR: