diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 09a8d8fe1f..61d6d73277 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -1,12 +1,18 @@ from collections.abc import Sequence from pytensor.compile.ops import TypeCastingOp +from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined from pytensor.graph import Apply, Op from pytensor.graph.basic import Variable -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import TensorType, continuous_dtypes from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor +def grad_connected(var: Variable) -> bool: + """Whether an XOp input can carry a cotangent (a continuous-dtype xtensor).""" + return isinstance(var.type, XTensorType) and var.type.dtype in continuous_dtypes + + class XOp(Op): """A base class for XOps that shouldn't be materialized""" @@ -18,12 +24,59 @@ def perform(self, node, inputs, outputs): def do_constant_folding(self, fgraph, node): return False + def pullback(self, inputs, outputs, cotangents): + # XOps carry no gradient of their own. Defer to LazyGrad, which the + # expand_lazy_grad rewrite differentiates by lowering core_op to tensor ops and + # taking their pullback, so no XOp runs lowering inside its own pullback. Discrete + # xtensor inputs (e.g. integer indices) have an undefined gradient; structural + # inputs (slices, rngs) are disconnected. + from pytensor.xtensor.shape import zeros_like + + # A disconnected cotangent (no contribution from that output) becomes a zero, + # so LazyGrad never takes a DisconnectedType as an input. + cotangents = [ + zeros_like(out) if isinstance(cot.type, DisconnectedType) else cot + for cot, out in zip(cotangents, outputs) + ] + grads = iter( + LazyGrad(self, len(outputs))(*inputs, *cotangents, return_list=True) + ) + return [ + next(grads) + if grad_connected(inp) + else grad_undefined(self, i, inp) + if isinstance(inp.type, XTensorType) + else disconnected_type() + for i, inp in enumerate(inputs) + ] + def vectorize_node( self, node, *new_inputs, new_dim: str | None ) -> Sequence[Variable]: raise NotImplementedError(f"Vectorized node not implemented for {self}") +class LazyGrad(XOp): + """Deferred vector-Jacobian product of another XOp. + + Wraps the differentiated ``core_op`` with its inputs and the output cotangents. The + ``expand_lazy_grad`` rewrite differentiates it by lowering ``core_op`` to tensor ops + and taking their pullback, so no XOp ever runs lowering inside its own pullback. + There is one output per differentiable (continuous-dtype) input. + """ + + __props__ = ("core_op", "n_cotangents") + + def __init__(self, core_op: Op, n_cotangents: int): + self.core_op = core_op + self.n_cotangents = n_cotangents + + def make_node(self, *inputs): + forward_inputs = inputs[: -self.n_cotangents] + outputs = [inp.type() for inp in forward_inputs if grad_connected(inp)] + return Apply(self, list(inputs), outputs) + + class XTypeCastOp(TypeCastingOp): """Base class for Ops that type cast between TensorType and XTensorType. @@ -120,7 +173,7 @@ def make_node(self, x): def pullback(self, inputs, outs, g_outs): [x] = inputs [g_out] = g_outs - return [rename(g_out, dims=x.type.dims)] + return [type(self)(x.type.dims)(g_out)] def vectorize_node(self, node, new_x, new_dim): [old_x] = node.inputs diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py index 364a5c9965..24ad6ef6f4 100644 --- a/pytensor/xtensor/rewriting/basic.py +++ b/pytensor/xtensor/rewriting/basic.py @@ -1,14 +1,25 @@ +from pytensor.gradient import DisconnectedType, pullback from pytensor.graph import node_rewriter +from pytensor.graph.basic import clone_get_equiv +from pytensor.graph.rewriting.utils import rewrite_graph +from pytensor.graph.traversal import ancestors, graph_inputs from pytensor.tensor.basic import register_infer_shape from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless from pytensor.xtensor.basic import ( + LazyGrad, Rename, TensorFromXTensor, + XOp, XTensorFromTensor, + grad_connected, xtensor_from_tensor, ) from pytensor.xtensor.random.type import RNGToXRNG, XRNGToRNG -from pytensor.xtensor.rewriting.utils import register_lower_xtensor +from pytensor.xtensor.rewriting.utils import ( + register_lower_lazy_grad, + register_lower_xtensor, +) +from pytensor.xtensor.shape import zeros_like @register_infer_shape @@ -85,3 +96,52 @@ def useless_xrng_to_rng(fgraph, node): [x] = node.inputs if x.owner and isinstance(x.owner.op, RNGToXRNG): return [x.owner.inputs[0]] + + +@register_lower_lazy_grad +@node_rewriter(tracks=[LazyGrad]) +def expand_lazy_grad(fgraph, node): + """Differentiate an XOp by lowering it to tensor ops and taking their pullback. + + Runs before lower_xtensor: the differentiated op (``core_op``) is rebuilt on fresh + stand-ins and lowered to tensor ops in isolation, then differentiated with the + ordinary tensor pullback. Stand-ins (rather than the real inputs) give a repeated + input separate per-slot cotangents, and survive the lowering of the conversion ops + that the real inputs would be folded into. + """ + op = node.op + forward_inputs = node.inputs[: -op.n_cotangents] + cotangents = node.inputs[-op.n_cotangents :] + + dummies = [inp.type() if grad_connected(inp) else inp for inp in forward_inputs] + lowered = rewrite_graph( + list(op.core_op.make_node(*dummies).outputs), + include=("lower_lazy_grad", "lower_xtensor"), + ) + if any(isinstance(var.owner.op, XOp) for var in ancestors(lowered) if var.owner): + raise NotImplementedError(f"pullback not implemented for {op.core_op}") + + memo = {d: inp for d, inp in zip(dummies, forward_inputs) if grad_connected(inp)} + input_grads = pullback( + lowered, + list(memo), + cotangents, + disconnected_inputs="ignore", + return_disconnected="disconnected", + ) + # The lowering and pullback above built nodes inside throwaway FunctionGraphs. Re-clone + # the grad into fresh nodes so it imports into the main graph through the normal path, + # grafting the real inputs back in place of the stand-ins. Real variables the grad + # already shares (the node inputs and any value the gradient reuses) are kept as-is. + keep = list(node.inputs) + [ + v + for v in graph_inputs(input_grads, blockers=node.inputs) + if v not in memo and v not in set(node.inputs) + ] + equiv = clone_get_equiv(keep, input_grads, copy_inputs=False, memo=dict(memo)) + # An input the cost doesn't reach through this node contributes a zero (its other + # paths are summed in by the grad engine); a node output can't be DisconnectedType. + return [ + zeros_like(inp) if isinstance(grad.type, DisconnectedType) else equiv[grad] + for grad, inp in zip(input_grads, memo.values()) + ] diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py index b3d6433658..0b15420bd4 100644 --- a/pytensor/xtensor/rewriting/utils.py +++ b/pytensor/xtensor/rewriting/utils.py @@ -12,12 +12,26 @@ lower_xtensor_db = EquilibriumDB(ignore_newtrees=False) +# Expanding the lazy gradient Op (LazyGrad) rewrites a whole grad subgraph in one shot, +# which is unsafe to splice into the lower_xtensor equilibrium mid-flight. It runs just +# before it instead, so the expanded grad is lowered by the normal pass like any other. +lower_lazy_grad_db = EquilibriumDB(ignore_newtrees=False) + infer_shape_db.register( "lower_xtensor", lower_xtensor_db, "infer_shape", ) +optdb.register( + "lower_lazy_grad", + lower_lazy_grad_db, + "fast_run", + "fast_compile", + "minimum_compile", + position=0.089, # before lower_xtensor +) + optdb.register( "lower_xtensor", lower_xtensor_db, @@ -64,6 +78,19 @@ def register(inner_rewriter: RewriteDatabase | NodeRewriter): return node_rewriter +def register_lower_lazy_grad(node_rewriter: NodeRewriter, **kwargs): + name = kwargs.pop("name", None) or node_rewriter.__name__ # type: ignore + lower_lazy_grad_db.register( + name, + node_rewriter, + "fast_run", + "fast_compile", + "minimum_compile", + **kwargs, + ) + return node_rewriter + + def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable: """Lower an XTensorVariable to a TensorVariable so that it's dimensions are aligned with "out_dims".""" inp_dims = {d: i for i, d in enumerate(x.type.dims)} diff --git a/tests/xtensor/test_grad.py b/tests/xtensor/test_grad.py new file mode 100644 index 0000000000..deaa0e6dff --- /dev/null +++ b/tests/xtensor/test_grad.py @@ -0,0 +1,135 @@ +import pytest + + +pytest.importorskip("xarray") +pytestmark = pytest.mark.filterwarnings("error") + +import numpy as np + +import pytensor +import pytensor.tensor as pt +import pytensor.xtensor as px +from pytensor.graph import rewrite_graph +from pytensor.xtensor.type import as_xtensor +from tests.unittest_tools import verify_grad + + +def grad_through_lowering(cost, wrt): + """Reference: lower the xtensor graph to tensor ops, then take the gradient.""" + cost = rewrite_graph(cost, include=("lower_xtensor",), clone=True) + return pt.grad(cost, wrt) + + +def _x(): + xt = pt.tensor("x", shape=(3, 4)) + return xt, as_xtensor(xt, dims=("a", "b")) + + +def _y(): + yt = pt.tensor("y", shape=(4, 2)) + return yt, as_xtensor(yt, dims=("b", "c")) + + +def build_cases(): + xt, x = _x() + yt, y = _y() + return [ + ("reduce_sum", (px.math.exp(x).sum("a") * 1.5).sum(), [xt]), + ("reduce_mean_std", (x.mean("a") + x.std("a")).sum(), [xt]), + ("reduce_max", (x.max("a") * 1.5).sum(), [xt]), + ("reduce_min", (x.min("a") * 1.5).sum(), [xt]), + ("cumsum", px.math.exp(x).cumsum("a").sum(), [xt]), + ("elemwise", (px.math.tanh(x) * px.math.sin(x)).sum(), [xt]), + ("transpose", (x.transpose("b", "a") ** 2).sum(), [xt]), + ("concat", px.concat([x, x + 1.0], dim="a").sum(), [xt]), + ("stack", px.math.exp(x).stack({"z": ("a", "b")}).sum(), [xt]), + ("rename", (x.rename({"a": "a2"}) ** 2).sum(), [xt]), + # Swapping names exercises Rename as a positional relabel (not a permutation). + ("rename_swap", (x.rename({"a": "b", "b": "a"}).sum("a") ** 2).sum(), [xt]), + ("dot", (px.dot(x, y, dim="b") ** 2).sum(), [xt, yt]), + ] + + +# Differentiating mean/std lowers to several Shape(x) views that duplicate the +# forward's; merging those duplicates while the destroy handler is attached leaves its +# client bookkeeping inconsistent, which `on_opt_error=raise` (used in tests) turns +# fatal. The gradient itself is correct, so the failure is in graph optimization only. +_XFAIL_DESTROY_HANDLER = {"reduce_mean_std"} + + +@pytest.mark.parametrize( + "loss, wrt", + [ + pytest.param( + loss, + wrt, + id=name, + marks=pytest.mark.xfail( + reason="merging duplicated Shape views upsets the destroy handler", + strict=True, + ) + if name in _XFAIL_DESTROY_HANDLER + else (), + ) + for name, loss, wrt in build_cases() + ], +) +def test_grad_matches_lowering(loss, wrt): + # pt.grad must work directly on the un-lowered xtensor graph and agree with the + # supported "lower first, then grad" path. + rng = np.random.default_rng(7) + test_vals = [rng.normal(size=w.type.shape).astype(w.type.dtype) for w in wrt] + g_direct = pt.grad(loss.values, wrt) + g_ref = grad_through_lowering(loss.values, wrt) + fn = pytensor.function(wrt, [*g_direct, *g_ref]) + out = fn(*test_vals) + n = len(wrt) + for direct, ref in zip(out[:n], out[n:]): + np.testing.assert_allclose(direct, ref) + + +def test_grad_repeated_input(): + # A repeated input must accumulate per-slot cotangents (no factor-of-N error). + xt = pt.vector("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + x_test = np.array([1.0, 2.0, 3.0]) + for power, loss in [(2, (x * x).sum()), (3, (x * x * x).sum())]: + g = pytensor.function([xt], pt.grad(loss.values, xt))(x_test) + np.testing.assert_allclose(g, power * x_test ** (power - 1)) + + +def test_grad_second_order(): + W = pytensor.shared(np.ones((3, 2)), name="W") + xt = pt.vector("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + y = px.dot(x, as_xtensor(W, dims=("a", "b")), dim="a") + loss = (y * y).sum() + g2 = pt.grad(pt.grad(loss.values, W).sum(), W) + g2_ref = pt.grad(grad_through_lowering(loss.values, W).sum(), W) + direct, ref = pytensor.function([xt], [g2, g2_ref])(np.arange(3.0)) + np.testing.assert_allclose(direct, ref) + + +def test_grad_through_indexing(): + # The index itself is non-differentiable (an integer xtensor) so it gets an + # undefined gradient, but the array input's gradient is still correct: a scatter + # of the cotangent into the indexed positions. + xt = pt.tensor("x", shape=(3, 4)) + x = as_xtensor(xt, dims=("a", "b")) + loss = (x.isel(a=1) ** 2).sum() + grad = pt.grad(loss.values, xt) + x_test = np.arange(12.0).reshape(3, 4) + expected = np.zeros((3, 4)) + expected[1] = 2 * x_test[1] + np.testing.assert_allclose(pytensor.function([xt], grad)(x_test), expected) + + +def test_verify_grad(): + rng = np.random.default_rng(seed=420) + + def dot_loss(x, w): + xx = as_xtensor(x, dims=("a",)) + ww = as_xtensor(w, dims=("a", "b")) + return (px.dot(xx, ww, dim="a") ** 2).sum().values + + verify_grad(dot_loss, [rng.normal(size=(3,)), rng.normal(size=(3, 2))], rng=rng)