Skip to content

Improve Feature hook API and pickle round-trip #2099

Merged
ricardoV94 merged 14 commits intopymc-devs:mainfrom
jessegrabowski:feature-pickle
May 1, 2026
Merged

Improve Feature hook API and pickle round-trip #2099
ricardoV94 merged 14 commits intopymc-devs:mainfrom
jessegrabowski:feature-pickle

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Apr 29, 2026

I have been periodically hitting rewrite failures in PyMC, it turns out due to pickle roundtripping of fgraph that have been partially rewritten. Here is the specific error:

ERROR (pytensor.graph.rewriting.basic): SequentialGraphRewriter apply <pytensor.tensor.rewriting.elemwise.InplaceElemwiseOptimizer object at 0x331d13380>
ERROR (pytensor.graph.rewriting.basic): Traceback:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/jesse.grabowski/Python/systematic-credit/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 289, in apply
    sub_prof = rewriter.apply(fgraph)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jesse.grabowski/Python/systematic-credit/.pixi/envs/default/lib/python3.12/site-packages/pytensor/tensor/rewriting/elemwise.py", line 205, in apply
    fgraph.replace_all_validate(replacements, reason=reason)
  File "/Users/jesse.grabowski/Python/systematic-credit/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/features.py", line 731, in replace_all_validate
    chk = fgraph.checkpoint()
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/jesse.grabowski/Python/systematic-credit/.pixi/envs/default/lib/python3.12/site-packages/pytensor/graph/features.py", line 358, in __call__
    self.h.history[self.fgraph] = []
    ^^^^^^^^^^^^^^
AttributeError: 'ReplaceValidate' object has no attribute 'history'

How we got here

History originally monkey-patched fgraph.checkpoint with a lambda, but lambdas aren't pickleable. Commit 2dc5129 added pickle_rm_attr to strip those lambdas off fgraph.dict on pickle. The history dict itself was also full of lambdas, so a __getstate__ on ReplaceValidate dropped it too. Commit ddb24d7 then replaced the closures with pickleable classes (GetCheckpoint, LambdaExtract), making history pickleable in principle — but never undid the strip, and never added restoration code on the unpickle side.

So unpickle reinstalls fgraph.checkpoint but doesn't recreate self.history. The next replace_all_validate calls fgraph.checkpoint() and crashes. Latent for 11 years because nothing pickled a fgraph and then re-rewrote it. I think I hit it because we're doing rewrite_pregradon pymc models, then pickle the graph and finish compiling on workers, but only in certain multiprocessing harnesses (pm.sample_smc, pmx.fit_pathfinder).

Why a refactor instead of a patch

I started with a one-line restoration of history, but the underlying Feature API has two hand-maintained matched pairs that know nothing about each other: pickle_rm_attr <-> unpickle, and __getstate__ <-> __setstate__. They have to stay in sync or pickle bugs of this exact shape sneak in. The same bug existed unfound for execute_callbacks_times, which my initial patch immediately surfaced.

What this PR does

Feature methods now belong to one of two groups, declared explicitly:

  • Callbacks are decorated with @register_feature_callback. FunctionGraph.execute_callbacks only
    dispatches registered callbacks.
  • provides is a tuple of method names listed on the Feature class. Anything listed there is callable as
    fgraph.<name>(...) and resolves via __getattr__ to feature.<name>(fgraph, ...).

Two advantages:

  1. VERY small perf boost on provides dispatch — FunctionGraph checks a whitelist of names instead of looping
    through _features with try/except on every callback.
  2. The pickle hacks are gone. No more pickle_rm_attr, no more unpickle. The pickle protocol is just
    __getstate__ and __setstate__, and everything works.

The only remaining hack: execute_callbacks_times is now a lazy property, so it never lands in __dict__
and doesn't need special handling on pickle. I don't love it, but it works.

@jessegrabowski jessegrabowski added the bug Something isn't working label Apr 29, 2026
@jessegrabowski jessegrabowski changed the title feature-pickle Improve Feature hook API and pickle round-trip Apr 29, 2026
@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 29, 2026

question: why are we now rewriting pickled fgraphs (and never did before)?

Is it scan in numba only compiled/evaled after multiprocessing split? frozen fgraphs?

Not relevant for the fix itself just for my mental model

@jessegrabowski
Copy link
Copy Markdown
Member Author

It's not exclusive to scan, I can trigger the error with a regular GLM by using pmx.fit_pathfinder.

My guess is there's a subtle difference in the multiprocessing code between pm.sample and other samplers, but it's just a guess.

@ricardoV94
Copy link
Copy Markdown
Member

It's not exclusive to scan, I can trigger the error with a regular GLM by using pmx.fit_pathfinder.

My guess is there's a subtle difference in the multiprocessing code between pm.sample and other samplers, but it's just a guess.

pm sample evals the function at leastonce (init nuts jitter to find a valid point). If you don't eval you don't trigger numba compile. I assume this requires an inner graph that gets rewritten though (ofg or scan).

It may make sense to force compilation before the dispatch (or you end up doing the same work per process for no gain). Fix is still valid but can you confirm?

@jessegrabowski
Copy link
Copy Markdown
Member Author

jessegrabowski commented Apr 29, 2026

I can't consistently trigger the bug from a trivial MWE, so i'm not sure how to test that

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 30, 2026

I can't consistently trigger the bug from a trivial MWE, so i'm not sure how to test that

We can read the correctness out of this one. It may require cloudpickle specifically. You also need a rewrite that would back-out (like try inplace eagerly and then find out it's invalid)

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 30, 2026

MWE:

import pickle
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(5,))
ofg = pytensor.OpFromGraph([x], [x[[0, 1, 2]].set(1)])
f = pytensor.function([x], ofg(x), mode="NUMBA")

pickle.loads(pickle.dumps(f, protocol=-1))

@ricardoV94
Copy link
Copy Markdown
Member

This is happening because on pickling the numba dispatch is regtriggered, and the OFG is optimized again... one more point for rewrite inner graphs as part of the global rewrite and do only transpilation at dispatch time

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also faster on the rewrite benchmark!!!

Edit: NVM I must have compared with a wrong "main" branch. Same speed, just simpler

@ricardoV94
Copy link
Copy Markdown
Member

I'm pushing a simplification, no hack to make node/graph rewrite pickable (they shouldn't have to be), and the current approach would break for rewrites defined dynmacially/inside a function). Reason is only used for string repr why a change happened, so just keep the string in pickling.

@ricardoV94
Copy link
Copy Markdown
Member

Follow up: reason should always be a string, not an actual rewrite, it's only ever used for compiler_verbose printing anyway.

@jessegrabowski
Copy link
Copy Markdown
Member Author

Nice, you were right that _restore_decorated_rewriter was a code smell in the first place. Glad to see it easily eliminated.

@ricardoV94 ricardoV94 merged commit d6b4023 into pymc-devs:main May 1, 2026
66 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants