From 461dc7a9063ff1c67c88116bb157cab49364d2cf Mon Sep 17 00:00:00 2001 From: Margus Niitsoo Date: Sat, 27 Jun 2026 17:17:37 +0300 Subject: [PATCH] Add Function.reseed_rngs to reseed a compiled function's RNGs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A compiled function's RNG state can only be reseeded through the original shared variables for backends that read them at call time (C, numba). The JAX backend copies and typifies its RNG shared variables at compile time, so those copies cannot be reached afterwards — reseeding the originals has no effect. Add `Function.reseed_rngs(seed)`, which seeds the function's own RNG storage containers directly. A `reseeded_rng_value` singledispatch converts a fresh Generator into the representation each backend stores (a plain Generator by default; the JAX backend registers the typified state dict), so the method is backend-agnostic. Co-Authored-By: Claude Opus 4.8 (1M context) --- pytensor/compile/executor.py | 38 ++++++++++++++++++++++++++++ pytensor/link/jax/dispatch/random.py | 7 +++++ tests/compile/test_executor.py | 17 +++++++++++++ tests/link/jax/test_random.py | 15 +++++++++++ 4 files changed, 77 insertions(+) diff --git a/pytensor/compile/executor.py b/pytensor/compile/executor.py index 9de44ad9e3..2aa44c34af 100644 --- a/pytensor/compile/executor.py +++ b/pytensor/compile/executor.py @@ -4,6 +4,7 @@ import copyreg import time import warnings +from functools import singledispatch from typing import TYPE_CHECKING import numpy as np @@ -37,6 +38,18 @@ class AliasedMemoryError(Exception): DUPLICATE = object() +@singledispatch +def reseeded_rng_value(current, generator: np.random.Generator): + """Return ``generator`` in the storage representation of ``current``. + + Used by :meth:`Function.reseed_rngs`. Most backends store a NumPy ``Generator`` + directly, so the default returns it unchanged. Backends that store RNGs in another + representation register a conversion keyed on that representation (e.g. the JAX backend + stores a state ``dict``). + """ + return generator + + class Function: r"""A class that wraps the execution of a `VM` making it easier for use as a "function". @@ -810,6 +823,31 @@ def get_shared(self): """ return [i.variable for i in self.maker.inputs if i.implicit] + def reseed_rngs(self, seed=None) -> None: + """Reseed the random generators used by this function. + + Each random input is set to a fresh stream spawned from ``seed`` (an ``int``, + sequence of ints, or ``SeedSequence``; ``None`` draws fresh entropy). This works + for every backend, including JAX, whose compiled functions copy their RNGs at + compile time and so cannot be reseeded through the original shared variables. + """ + from pytensor.tensor.random.type import RandomType + + rng_containers = [ + container + for inp, container in zip( + self.maker.expanded_inputs, self.input_storage, strict=True + ) + if isinstance(inp.variable.type, RandomType) + ] + if not rng_containers: + return + + seed_seqs = np.random.SeedSequence(seed).spawn(len(rng_containers)) + for container, seed_seq in zip(rng_containers, seed_seqs, strict=True): + generator = np.random.Generator(np.random.PCG64(seed_seq)) + container.storage[0] = reseeded_rng_value(container.storage[0], generator) + def dprint(self, **kwargs): """Debug print itself diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index edd9fff8b5..ef20c4dbd2 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -9,6 +9,7 @@ ) import pytensor.tensor.random.basic as ptr +from pytensor.compile.executor import reseeded_rng_value from pytensor.graph import Constant from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify from pytensor.link.jax.dispatch.shape import JAXShapeTuple @@ -80,6 +81,12 @@ def jax_typify_Generator(rng, **kwargs): return state +@reseeded_rng_value.register(dict) +def reseeded_rng_value_jax(current, generator): + # JAX stores RNGs as the typified state dict produced by `jax_typify_Generator`. + return jax_typify_Generator(generator) + + @jax_funcify.register(ptr.RandomVariable) def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): """JAX implementation of random variables.""" diff --git a/tests/compile/test_executor.py b/tests/compile/test_executor.py index 7ac2ef9c1b..6997dd766a 100644 --- a/tests/compile/test_executor.py +++ b/tests/compile/test_executor.py @@ -1111,3 +1111,20 @@ def test_pickle_class_with_functions(self): blah.f2(5, 1) assert blah.f1._finder[blah.s].value != blah2.f1._finder[blah2.s].value + + +def test_reseed_rngs(): + rng = shared(np.random.default_rng(0)) + rv = pt.random.normal(0, 1, size=3, rng=rng) + f = function([], rv, updates={rng: rv.owner.outputs[0]}) + + f.reseed_rngs(123) + draw = f() + f.reseed_rngs(123) + np.testing.assert_array_equal(draw, f()) # same seed -> same draw + f.reseed_rngs(456) + assert not np.array_equal(draw, f()) # different seed -> different draw + + # A function without random inputs is a no-op. + x = scalar("x") + function([x], x * 2).reseed_rngs(0) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 53f91d9d34..29e1e1bf1f 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -963,3 +963,18 @@ def test_constant_shape_after_graph_rewriting(self): new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False) assert new_x.type.shape == (2, 5) assert compile_random_function([], new_x)().shape == (2, 5) + + +def test_reseed_rngs(): + # JAX copies RNG shared variables at compile time, so the originals can't be reseeded + # via set_value; Function.reseed_rngs reseeds the function's own (typified) RNG storage. + rng = shared(np.random.default_rng(0)) + rv = pt.random.normal(0, 1, size=3, rng=rng) + f = function([], rv, updates={rng: rv.owner.outputs[0]}, mode=jax_mode) + + f.reseed_rngs(123) + draw = np.asarray(f()) + f.reseed_rngs(123) + np.testing.assert_array_equal(draw, np.asarray(f())) # same seed -> same draw + f.reseed_rngs(456) + assert not np.array_equal(draw, np.asarray(f())) # different seed -> different draw