Add Function.reseed_rngs to reseed a compiled function's RNGs #2271
velochy wants to merge 1 commit
A compiled function's RNG state can only be reseeded through the original shared variables for backends that read them at call time (C, numba). The JAX backend copies and typifies its RNG shared variables at compile time, so those copies cannot be reached afterwards, and reseeding the originals has no effect.

Add `Function.reseed_rngs(seed)`, which seeds the function's own RNG storage containers directly. A `reseeded_rng_value` singledispatch converts a fresh Generator into the representation each backend stores (a plain Generator by default; the JAX backend registers the typified state dict), so the method is backend-agnostic.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
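The dispatch idea in the commit message can be illustrated with a minimal, self-contained sketch. It is not the PR's code: the two backend marker classes are hypothetical stand-ins, dispatching on a backend object is an assumption about what the real `reseeded_rng_value` keys on, and NumPy's raw `bit_generator.state` dict stands in for the typified state dict that the JAX backend actually stores.

```python
from functools import singledispatch

import numpy as np


class PlainBackend:
    """Hypothetical stand-in for a backend that stores the Generator as-is."""


class DictStateBackend:
    """Hypothetical stand-in for a backend that stores a state dict."""


@singledispatch
def reseeded_rng_value(backend, rng: np.random.Generator):
    # Default registration: the stored value is the Generator itself.
    return rng


@reseeded_rng_value.register
def _(backend: DictStateBackend, rng: np.random.Generator):
    # Assumption: the raw NumPy state dict approximates what a backend
    # with a converted representation would store.
    return rng.bit_generator.state


plain_value = reseeded_rng_value(PlainBackend(), np.random.default_rng(123))
dict_value = reseeded_rng_value(DictStateBackend(), np.random.default_rng(123))
```

The caller only ever builds a fresh `Generator`; the conversion to whatever a backend stores stays in that backend's own registration.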
@ricardoV94 so this is what AI was proposing for pymc-devs/pymc#8330. Or am I missing something that makes this a bad idea, or something that already provides this functionality?
Probably something more agnostic imo, Function.update_shareds? We may have other types that are compatible across backends, see #278. There's also a discussion/issue on this question for shared specifically, I think.
Can you elaborate a bit more on what you mean by that? Something that allows setting not just the random seed but more things? But what kinds of things? At least on a cursory look, RNG seems to be a special case, with some models taking a numpy generator but JAX requiring conversion to something else. So would the other things it is intended to set also require similar conversions, and we would just have a switch-case style block that chooses which one based on model type and what is being set? Honestly, from the side, reseed_rng does seem like the right level of concrete for that reason. But I do understand you not wanting to pollute the Function API with multiple different functions for niche use cases too, so I guess there might be a case for bundling there too.
Sparse variables, for instance. They are represented in JAX with BCOO. The one graph/multiple backends model breaks once we deviate from regular dense arrays, especially in shared variables, which are supposed to hold a value that can be present in multiple functions. This is why we need to break the connection for RNGs when we compile to JAX.

Also, reseeding the RNG is not really the right abstraction here. PyTensor is built of individual, separate "Generator" root variables. A caller who wants to reseed them from a common source of entropy should pass a distinct generator to each shared variable, not expect a PyTensor function to understand one rng -> magic split. This is why that sort of logic lives in PyMC, not PyTensor.

The biggest issue here is that these SharedVariables aren't really shared, and it's not even clear to the user what the input data should look like. The PR I linked tried to at least help with non-shared variables, but it's still not great. For SharedVariable there was some discussion in https://github.com/pymc-devs/design-notes/blob/d09f899a3a8c6e60ce739b3d7db440d70bb077bf/PyTensor%20design%20meeting%20(April%2014%2C%202023).md?plain=1#L3-L85
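The caller-side split described here, one source of entropy fanned out into distinct generators, can be sketched with NumPy alone. The helper name `spawn_generators` is hypothetical and is not taken from PyMC or PyTensor; it only illustrates the `SeedSequence.spawn` pattern.

```python
import numpy as np


def spawn_generators(seed, n):
    """Derive n independent Generators from a single seed (hypothetical helper)."""
    root = np.random.SeedSequence(seed)
    # Each child SeedSequence yields a statistically independent stream.
    return [np.random.default_rng(child) for child in root.spawn(n)]


# The same seed reproduces the same set of streams.
rngs_a = spawn_generators(42, 3)
rngs_b = spawn_generators(42, 3)
draws_a = [rng.normal() for rng in rngs_a]
draws_b = [rng.normal() for rng in rngs_b]
```

In this arrangement the caller would assign each generator to its own shared RNG variable, so the compiled function never needs to know how the seed was split.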
Ok, this feels like a very deep rabbit hole all of a sudden.
Motivation

Caching compiled PyTensor functions and reseeding them per call (see pymc-devs/pymc#8330) currently only works for backends that read their RNG shared variables at call time (C, numba): you can `reseed_rngs(shared_rngs, seed)` after compilation and the function picks it up.

The JAX backend copies and typifies its RNG shared variables at compile time (`JAXLinker.fgraph_convert`), so the compiled function uses detached copies stored as a state `dict`. Reseeding the original shared variables therefore has no effect, and a cached JAX function cannot be reseeded: same-seed calls diverge.

Change

Add `Function.reseed_rngs(seed)`, which seeds the function's own RNG storage containers directly, so it works for every backend including JAX.

A `reseeded_rng_value` `singledispatch` converts a fresh `Generator` into the representation each backend stores: a plain `Generator` by default, with the JAX backend registering the typified state `dict` (reusing `jax_typify_Generator`). This keeps backend-specific knowledge in the backend module, mirroring `jax_typify`.

`seed` accepts an `int`, a sequence of ints, or a `SeedSequence` (`None` draws fresh entropy), matching `RandomStream.seed`.

Tests

- `tests/compile/test_executor.py::test_reseed_rngs`: default backend. Same seed → same draw, different seed → differs, and a no-op when the function has no RNG inputs.
- `tests/link/jax/test_random.py::test_reseed_rngs`: the JAX case.

Verified reproducible reseeding for the C, numba and JAX backends, including reseeding before the first call; the default path does not import the JAX dispatch.
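The properties the tests check (same seed reproduces, a different seed differs, no RNG inputs is a no-op) can be shown on a toy object. `ToyCompiledFunction` is a hypothetical stand-in, not PyTensor's `Function`; the single-element lists imitate storage containers, and splitting one seed across containers with `SeedSequence.spawn` is an assumption about how a per-function reseed might distribute entropy.

```python
import numpy as np


class ToyCompiledFunction:
    """Hypothetical stand-in for a compiled function that owns its RNG storage."""

    def __init__(self, n_rngs):
        # One single-element list per RNG input, imitating a storage container.
        self.rng_storage = [[np.random.default_rng()] for _ in range(n_rngs)]

    def reseed_rngs(self, seed=None):
        # Assumption: one child seed per container, all derived from `seed`.
        children = np.random.SeedSequence(seed).spawn(len(self.rng_storage))
        for container, child in zip(self.rng_storage, children):
            container[0] = np.random.default_rng(child)

    def __call__(self):
        # Draw one value from each stored generator.
        return [container[0].normal() for container in self.rng_storage]


fn = ToyCompiledFunction(n_rngs=2)
fn.reseed_rngs(1)
first = fn()
fn.reseed_rngs(1)
second = fn()
fn.reseed_rngs(2)
third = fn()

# With no RNG inputs there is nothing to reseed.
no_rng_fn = ToyCompiledFunction(n_rngs=0)
no_rng_fn.reseed_rngs(1)
```

Because the reseed writes into the function's own containers, it does not matter whether the originals were copied at compile time, which is the situation described for JAX above.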