diff --git a/pytensor/compile/executor.py b/pytensor/compile/executor.py index 9de44ad9e3..2aa44c34af 100644 --- a/pytensor/compile/executor.py +++ b/pytensor/compile/executor.py @@ -4,6 +4,7 @@ import copyreg import time import warnings +from functools import singledispatch from typing import TYPE_CHECKING import numpy as np @@ -37,6 +38,18 @@ class AliasedMemoryError(Exception): DUPLICATE = object() +@singledispatch +def reseeded_rng_value(current, generator: np.random.Generator): + """Return ``generator`` in the storage representation of ``current``. + + Used by :meth:`Function.reseed_rngs`. Most backends store a NumPy ``Generator`` + directly, so the default returns it unchanged. Backends that store RNGs in another + representation register a conversion keyed on that representation (e.g. the JAX backend + stores a state ``dict``). + """ + return generator + + class Function: r"""A class that wraps the execution of a `VM` making it easier for use as a "function". @@ -810,6 +823,31 @@ def get_shared(self): """ return [i.variable for i in self.maker.inputs if i.implicit] + def reseed_rngs(self, seed=None) -> None: + """Reseed the random generators used by this function. + + Each random input is set to a fresh stream spawned from ``seed`` (an ``int``, + sequence of ints, or ``SeedSequence``; ``None`` draws fresh entropy). This works + for every backend, including JAX, whose compiled functions copy their RNGs at + compile time and so cannot be reseeded through the original shared variables. + """ + from pytensor.tensor.random.type import RandomType + + rng_containers = [ + container + for inp, container in zip( + self.maker.expanded_inputs, self.input_storage, strict=True + ) + if isinstance(inp.variable.type, RandomType) + ] + if not rng_containers: + return + + seed_seqs = np.random.SeedSequence(seed).spawn(len(rng_containers)) + for container, seed_seq in zip(rng_containers, seed_seqs, strict=True): + generator = np.random.Generator(np.random.PCG64(seed_seq)) + container.storage[0] = reseeded_rng_value(container.storage[0], generator) + def dprint(self, **kwargs): """Debug print itself diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index edd9fff8b5..ef20c4dbd2 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -9,6 +9,7 @@ ) import pytensor.tensor.random.basic as ptr +from pytensor.compile.executor import reseeded_rng_value from pytensor.graph import Constant from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify from pytensor.link.jax.dispatch.shape import JAXShapeTuple @@ -80,6 +81,12 @@ def jax_typify_Generator(rng, **kwargs): return state +@reseeded_rng_value.register(dict) +def reseeded_rng_value_jax(current, generator): + # JAX stores RNGs as the typified state dict produced by `jax_typify_Generator`. + return jax_typify_Generator(generator) + + @jax_funcify.register(ptr.RandomVariable) def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): """JAX implementation of random variables.""" diff --git a/tests/compile/test_executor.py b/tests/compile/test_executor.py index 7ac2ef9c1b..6997dd766a 100644 --- a/tests/compile/test_executor.py +++ b/tests/compile/test_executor.py @@ -1111,3 +1111,20 @@ def test_pickle_class_with_functions(self): blah.f2(5, 1) assert blah.f1._finder[blah.s].value != blah2.f1._finder[blah2.s].value + + +def test_reseed_rngs(): + rng = shared(np.random.default_rng(0)) + rv = pt.random.normal(0, 1, size=3, rng=rng) + f = function([], rv, updates={rng: rv.owner.outputs[0]}) + + f.reseed_rngs(123) + draw = f() + f.reseed_rngs(123) + np.testing.assert_array_equal(draw, f()) # same seed -> same draw + f.reseed_rngs(456) + assert not np.array_equal(draw, f()) # different seed -> different draw + + # A function without random inputs is a no-op. + x = scalar("x") + function([x], x * 2).reseed_rngs(0) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 53f91d9d34..29e1e1bf1f 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -963,3 +963,18 @@ def test_constant_shape_after_graph_rewriting(self): new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False) assert new_x.type.shape == (2, 5) assert compile_random_function([], new_x)().shape == (2, 5) + + +def test_reseed_rngs(): + # JAX copies RNG shared variables at compile time, so the originals can't be reseeded + # via set_value; Function.reseed_rngs reseeds the function's own (typified) RNG storage. + rng = shared(np.random.default_rng(0)) + rv = pt.random.normal(0, 1, size=3, rng=rng) + f = function([], rv, updates={rng: rv.owner.outputs[0]}, mode=jax_mode) + + f.reseed_rngs(123) + draw = np.asarray(f()) + f.reseed_rngs(123) + np.testing.assert_array_equal(draw, np.asarray(f())) # same seed -> same draw + f.reseed_rngs(456) + assert not np.array_equal(draw, np.asarray(f())) # different seed -> different draw