Improve Feature hook API and pickle round-trip #2099
ricardoV94 merged 14 commits into pymc-devs:main
…(no consumers yet)
…onsistent_ to validate/consistent
…pickle_rm_attr/unpickle and history-strip __getstate__
…rs set to _destroying_apps; fix destroy_handler typo in on_detach
…__getstate__/__setstate__ pickle plumbing
…k registry; switch execute_callbacks/collect_callbacks to use it
|
Question: why are we now rewriting pickled fgraphs when we never did before? Is it that scan in numba is only compiled/evaluated after the multiprocessing split? Frozen fgraphs? Not relevant for the fix itself, just for my mental model. |
|
It's not exclusive to scan, I can trigger the error with a regular GLM by using My guess is there's a subtle difference in the multiprocessing code between |
pm.sample evals the function at least once (init NUTS jitter to find a valid point). If you don't eval, you don't trigger numba compile. I assume this requires an inner graph that gets rewritten though (OFG or scan). It may make sense to force compilation before the dispatch (otherwise you end up doing the same work per process for no gain). The fix is still valid, but can you confirm? |
|
I can't consistently trigger the bug from a trivial MWE, so I'm not sure how to test that. |
bdd93f5 to b9d32e4
We can read the correctness out of this one. It may require cloudpickle specifically. You also need a rewrite that would back-out (like try inplace eagerly and then find out it's invalid) |
|
MWE:
import pickle
import pytensor
import pytensor.tensor as pt
x = pt.vector("x", shape=(5,))
ofg = pytensor.OpFromGraph([x], [x[[0, 1, 2]].set(1)])
f = pytensor.function([x], ofg(x), mode="NUMBA")
pickle.loads(pickle.dumps(f, protocol=-1)) |
|
This is happening because on pickling the numba dispatch is retriggered, and the OFG is optimized again... one more point in favor of rewriting inner graphs as part of the global rewrite and doing only transpilation at dispatch time. |
|
I'm pushing a simplification: no hack to make node/graph rewrites pickleable (they shouldn't have to be, and the current approach would break for rewrites defined dynamically or inside a function). Reason is only used for the string repr of why a change happened, so just keep the string in pickling. |
|
Follow up: reason should always be a string, not an actual rewrite, it's only ever used for compiler_verbose printing anyway. |
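A minimal sketch of the "keep only the string" idea, for illustration only. ToyChangeRecord, make_local_rewrite and LocalRewrite are hypothetical names, not PyTensor classes; the only assumption carried over from the comments above is that reason is needed purely for printing.

```python
import pickle


def make_local_rewrite():
    # A class defined inside a function cannot be pickled by reference,
    # which mimics a rewrite that is created dynamically.
    class LocalRewrite:
        def __str__(self):
            return "local_rewrite"

    return LocalRewrite()


class ToyChangeRecord:
    """Hypothetical record of a graph change and why it happened."""

    def __init__(self, node_name, reason):
        self.node_name = node_name
        self.reason = reason  # may be an arbitrary, non-picklable object

    def __getstate__(self):
        state = self.__dict__.copy()
        # Keep only the printable form of the reason in the pickled state.
        state["reason"] = str(state["reason"])
        return state


record = ToyChangeRecord("add_1", make_local_rewrite())
restored = pickle.loads(pickle.dumps(record))
print(restored.reason)
```

The rewrite object itself never has to be picklable; only its string form crosses the pickle boundary.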
|
Nice, you were right that |
I have been periodically hitting rewrite failures in PyMC, which turn out to be caused by pickle round-tripping of fgraphs that have been partially rewritten. Here is the specific error:

How we got here
History originally monkey-patched fgraph.checkpoint with a lambda, but lambdas aren't pickleable. Commit 2dc5129 added pickle_rm_attr to strip those lambdas off fgraph.__dict__ on pickle. The history dict itself was also full of lambdas, so a __getstate__ on ReplaceValidate dropped it too. Commit ddb24d7 then replaced the closures with pickleable classes (GetCheckpoint, LambdaExtract), making history pickleable in principle, but it never undid the strip and never added restoration code on the unpickle side.

So unpickle reinstalls fgraph.checkpoint but doesn't recreate self.history. The next replace_all_validate calls fgraph.checkpoint() and crashes. The bug stayed latent for 11 years because nothing pickled an fgraph and then re-rewrote it. I think I hit it because we're doing rewrite_pregrad on PyMC models, then pickling the graph and finishing compilation on workers, but only in certain multiprocessing harnesses (pm.sample_smc, pmx.fit_pathfinder).

Why a refactor instead of a patch
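A toy reproduction of the failure shape, as a sketch rather than PyTensor code: ToyHistory and FixedToyHistory are hypothetical classes. State is stripped in __getstate__, nothing restores it on unpickle, and the failure only shows up on the first call after a round trip.

```python
import pickle


class ToyHistory:
    """Hypothetical stand-in for a feature that keeps per-graph state."""

    def __init__(self):
        self.history = []  # state that __getstate__ strips below
        self.checkpoints_taken = 0

    def checkpoint(self):
        # Reads self.history, so it fails if the attribute is missing.
        self.checkpoints_taken += 1
        return len(self.history)

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("history", None)  # strip on pickle ...
        return state

    # ... but no __setstate__ restores it: the matched pair is out of sync.


class FixedToyHistory(ToyHistory):
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.history = []  # restore what __getstate__ removed


broken = pickle.loads(pickle.dumps(ToyHistory()))
try:
    broken.checkpoint()
    failed = False
except AttributeError:
    failed = True

fixed = pickle.loads(pickle.dumps(FixedToyHistory()))
restored_length = fixed.checkpoint()
```

Nothing goes wrong at pickle or unpickle time in the broken variant, which is consistent with how a bug of this shape can stay hidden until something re-rewrites a round-tripped graph.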
I started with a one-line restoration of history, but the underlying Feature API has two hand-maintained matched pairs that know nothing about each other: pickle_rm_attr <-> unpickle, and __getstate__ <-> __setstate__. They have to stay in sync or pickle bugs of this exact shape sneak in. The same bug existed unfound for execute_callbacks_times, which my initial patch immediately surfaced.

What this PR does
Feature methods now belong to one of two groups, declared explicitly:
dispatches registered callbacks.
fgraph.<name>(...) and resolves via __getattr__ to feature.<name>(fgraph, ...).

Two advantages:
FunctionGraph checks a whitelist of names instead of looping through _features with try/except on every callback.
__getstate__ and __setstate__, and everything works.

The only remaining hack:
execute_callbacks_times is now a lazy property, so it never lands in __dict__ and doesn't need special handling on pickle. I don't love it, but it works.
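One way a lazy property can stay out of __dict__, sketched under assumptions: the description above does not say where the PR actually stores the value, so the weak-keyed side table here is a guess at the general technique, and ToyGraph, callback_times and _TIMES are hypothetical names.

```python
import pickle
import weakref

# Side table keyed by instance; an entry disappears when its instance is collected.
_TIMES = weakref.WeakKeyDictionary()


class ToyGraph:
    """Hypothetical stand-in for a graph object with timing bookkeeping."""

    def __init__(self, name):
        self.name = name

    @property
    def callback_times(self):
        # Created on first access and stored outside the instance __dict__,
        # so default pickling never sees it.
        return _TIMES.setdefault(self, {})


g = ToyGraph("g")
g.callback_times["on_change"] = 0.25

restored = pickle.loads(pickle.dumps(g))
print("callback_times" in g.__dict__)  # False
print(restored.callback_times)  # {}
```

The unpickled copy starts with empty timings and recreates them on first access, so no __getstate__ or __setstate__ entry is needed for this attribute.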