
Add Function.reseed_rngs to reseed a compiled function's RNGs #2271

Draft
velochy wants to merge 1 commit into pymc-devs:main from velochy:jax-reseed-rngs

Conversation

velochy commented Jun 30, 2026

Contributor

Motivation

Caching compiled PyTensor functions and reseeding them per call (see pymc-devs/pymc#8330) currently only works for backends that read their RNG shared variables at call time (C, numba): you can reseed_rngs(shared_rngs, seed) after compilation and the function picks it up.

The JAX backend copies and typifies its RNG shared variables at compile time (JAXLinker.fgraph_convert), so the compiled function uses detached copies stored as a state dict. Reseeding the original shared variables therefore has no effect, and a cached JAX function cannot be reseeded — same-seed calls diverge.
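The difference between the two kinds of backend can be illustrated with a toy model in plain NumPy. This is only a sketch: `SharedRNG`, `compile_call_time`, `compile_with_copy` and `reseed` are hypothetical stand-ins, not PyTensor code, and the "compiled" functions are ordinary closures.

```python
import copy

import numpy as np


class SharedRNG:
    """Hypothetical stand-in for an RNG shared variable: a box holding a Generator."""

    def __init__(self, seed):
        self.value = np.random.default_rng(seed)


def compile_call_time(shared):
    # Reads shared.value on every call, like the backends that read
    # their RNG shared variables at call time.
    def fn():
        return shared.value.normal()

    return fn


def compile_with_copy(shared):
    # Copies the RNG once at "compile" time; later changes to the
    # shared variable never reach this detached copy.
    detached = copy.deepcopy(shared.value)

    def fn():
        return detached.normal()

    return fn


def reseed(shared, seed):
    # Replace the Generator held by the shared variable.
    shared.value = np.random.default_rng(seed)


shared = SharedRNG(0)
f_live = compile_call_time(shared)
f_copy = compile_with_copy(shared)

reseed(shared, 123)
live_a, copy_a = f_live(), f_copy()
reseed(shared, 123)
live_b, copy_b = f_live(), f_copy()

print(live_a == live_b)  # the call-time reader sees the reseed
print(copy_a == copy_b)  # the detached copy just keeps advancing
```

In this toy model the call-time reader reproduces its draw after each reseed, while the copying variant ignores the reseed and returns consecutive draws from its own stream.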

Change

Add Function.reseed_rngs(seed), which seeds the function's own RNG storage containers directly, so it works for every backend including JAX.

A reseeded_rng_value singledispatch converts a fresh Generator into the representation each backend stores — a plain Generator by default, with the JAX backend registering the typified state dict (reusing jax_typify_Generator). This keeps backend-specific knowledge in the backend module, mirroring jax_typify.
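A rough sketch of the dispatch idea, using only `functools.singledispatch` and NumPy. Several things here are assumptions for illustration: the sketch dispatches on the type of the value currently stored in the container, the containers are one-element lists, and the bit generator's state dict stands in for the typified JAX state. The names `reseeded_value_sketch` and `reseed_container` are hypothetical and are not the functions added in this PR.

```python
from functools import singledispatch

import numpy as np


@singledispatch
def reseeded_value_sketch(stored, rng):
    """Convert a fresh Generator into the representation `stored` uses."""
    raise NotImplementedError(f"No conversion registered for {type(stored)}")


@reseeded_value_sketch.register(np.random.Generator)
def _(stored, rng):
    # Default case: the container holds a plain Generator, so store the new one.
    return rng


@reseeded_value_sketch.register(dict)
def _(stored, rng):
    # Stand-in for a backend that stores a state dict instead of a Generator.
    return rng.bit_generator.state


def reseed_container(container, seed):
    # `container` is a one-element list standing in for a storage container.
    fresh = np.random.default_rng(seed)
    container[0] = reseeded_value_sketch(container[0], fresh)


generator_box = [np.random.default_rng(0)]
state_box = [np.random.default_rng(0).bit_generator.state]

reseed_container(generator_box, 42)
reseed_container(state_box, 42)

print(type(generator_box[0]).__name__, type(state_box[0]).__name__)
```

The point of the dispatch is that the reseeding loop never needs to know which backend it is talking to; each backend registers its own conversion next to its other backend-specific code.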

seed accepts an int, a sequence of ints, or a SeedSequence (None draws fresh entropy), matching RandomStream.seed.
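One way to accept all of these seed forms is to normalise them to a `numpy.random.SeedSequence` first. The helper below is a hypothetical sketch (`as_seed_sequence` and `first_draw` are not names from this PR) and relies only on documented NumPy behaviour.

```python
import numpy as np


def as_seed_sequence(seed=None):
    """Normalise an int, a sequence of ints, a SeedSequence or None."""
    if isinstance(seed, np.random.SeedSequence):
        return seed
    # SeedSequence accepts an int or a sequence of ints, and draws
    # fresh entropy from the OS when given None.
    return np.random.SeedSequence(seed)


def first_draw(seed):
    # Build a Generator from the normalised seed and take one draw.
    return np.random.default_rng(as_seed_sequence(seed)).normal()


print(first_draw(7) == first_draw(7))
print(first_draw([1, 2, 3]) == first_draw([1, 2, 3]))
print(first_draw(np.random.SeedSequence(7)) == first_draw(7))
```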

Tests

  • tests/compile/test_executor.py::test_reseed_rngs — default backend: same seed → same draw, different seed → differs, and a no-op when the function has no RNG inputs.
  • tests/link/jax/test_random.py::test_reseed_rngs — the JAX case.

Verified reproducible reseeding for the C, numba and JAX backends, including reseeding before the first call; the default path does not import the JAX dispatch.

A compiled function's RNG state can only be reseeded through the original
shared variables for backends that read them at call time (C, numba). The
JAX backend copies and typifies its RNG shared variables at compile time,
so those copies cannot be reached afterwards — reseeding the originals has
no effect.

Add `Function.reseed_rngs(seed)`, which seeds the function's own RNG
storage containers directly. A `reseeded_rng_value` singledispatch
converts a fresh Generator into the representation each backend stores
(a plain Generator by default; the JAX backend registers the typified
state dict), so the method is backend-agnostic.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

velochy commented Jun 30, 2026

Contributor Author

@ricardoV94 so this is what AI was proposing for pymc-devs/pymc#8330
It seems like a nice clean change to unify randomness (re)seeding for jax and other backends, hopefully leading to cleaner code in other places down the line too.

Or am I missing something that makes this a bad idea or already provides this functionality?

@ricardoV94

Member

Probably something more agnostic imo, Function.update_shareds? We may have other types that are compatible across backends, see #278. There's also a discussion/issue on this question for shared specifically, I think.

@velochy

velochy commented Jun 30, 2026

Contributor Author

> Probably something more agnostic imo, Function.update_shareds?

Can you elaborate a bit more on what you mean by that? Something that allows setting not just the random seed but more things? But what kinds of things?

Because at least on a cursory look, it seems rng is a special case with some models taking numpy generator but jax requiring conversion to something else. So would the other things it is intended to set also require similar conversions and we just have a switch-case style block that chooses which one based on model type and what is being set?

Honestly, from the outside, reseed_rngs does seem like the right level of concreteness for that reason. But I do understand you not wanting to pollute the Function API with multiple different functions for niche use cases, so I guess there might be a case for bundling there too.

@ricardoV94

ricardoV94 commented Jun 30, 2026

Member

> But what kinds of things?

Sparse variables for instance. They are represented in JAX with BCOO.

The one graph/multiple backends idea breaks once we deviate from regular dense arrays, and especially so in Shared variables, which are supposed to hold a value that can be present in multiple functions. Hence why we need to break the connection for RNGs when we compile to JAX.

Also, reseed_rngs is not really the right abstraction here. A PyTensor graph is built of individual, separate "Generator" root variables. The caller who wants to reseed them from a common source of entropy should pass distinct ones to each shared variable, not expect a PyTensor function to understand one rng -> magic split. This is why we have that sort of logic living in PyMC, not PyTensor.
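A minimal sketch of that caller-side splitting, using NumPy's `SeedSequence.spawn`. The one-element lists are hypothetical stand-ins for RNG shared variables, and `spawn_generators` is an illustrative helper, not an existing PyMC or PyTensor function.

```python
import numpy as np


def spawn_generators(seed, n):
    """Derive n independent Generators from one source of entropy."""
    children = np.random.SeedSequence(seed).spawn(n)
    return [np.random.default_rng(child) for child in children]


# Hypothetical stand-ins for two RNG shared variables (one-element boxes).
rng_boxes = [[None], [None]]

for box, rng in zip(rng_boxes, spawn_generators(2024, len(rng_boxes))):
    # With real shared variables the caller would set each value here.
    box[0] = rng

draws = [box[0].normal() for box in rng_boxes]
print(draws[0] != draws[1])  # each variable gets its own stream
```

Here the caller decides how one seed becomes several streams, and each variable only ever receives its own Generator.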

The biggest issue here is that these SharedVariables aren't really shared. And it's not even clear to the user what the input data should look like. The PR I linked tried to at least help with non-shared variables, but it's still not great. For SharedVariables there was some discussion in https://github.com/pymc-devs/design-notes/blob/d09f899a3a8c6e60ce739b3d7db440d70bb077bf/PyTensor%20design%20meeting%20(April%2014%2C%202023).md?plain=1#L3-L85

@velochy

velochy commented Jun 30, 2026

Contributor Author

Ok this feels like a very deep rabbithole all of a sudden.


2 participants