Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions pytensor/xtensor/basic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from collections.abc import Sequence

from pytensor.compile.ops import TypeCastingOp
from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import TensorType, continuous_dtypes
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def grad_connected(var: Variable) -> bool:
"""Whether an XOp input can carry a cotangent (a continuous-dtype xtensor)."""
return isinstance(var.type, XTensorType) and var.type.dtype in continuous_dtypes


class XOp(Op):
"""A base class for XOps that shouldn't be materialized"""

Expand All @@ -18,12 +24,59 @@ def perform(self, node, inputs, outputs):
def do_constant_folding(self, fgraph, node):
return False

def pullback(self, inputs, outputs, cotangents):
# XOps carry no gradient of their own. Defer to LazyGrad, which the
# expand_lazy_grad rewrite differentiates by lowering core_op to tensor ops and
# taking their pullback, so no XOp runs lowering inside its own pullback. Discrete
# xtensor inputs (e.g. integer indices) have an undefined gradient; structural
# inputs (slices, rngs) are disconnected.
from pytensor.xtensor.shape import zeros_like

# A disconnected cotangent (no contribution from that output) becomes a zero,
# so LazyGrad never takes a DisconnectedType as an input.
cotangents = [
zeros_like(out) if isinstance(cot.type, DisconnectedType) else cot
for cot, out in zip(cotangents, outputs)
]
grads = iter(
LazyGrad(self, len(outputs))(*inputs, *cotangents, return_list=True)
)
return [
next(grads)
if grad_connected(inp)
else grad_undefined(self, i, inp)
if isinstance(inp.type, XTensorType)
else disconnected_type()
for i, inp in enumerate(inputs)
]

def vectorize_node(
self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
raise NotImplementedError(f"Vectorized node not implemented for {self}")


class LazyGrad(XOp):
"""Deferred vector-Jacobian product of another XOp.

Wraps the differentiated ``core_op`` with its inputs and the output cotangents. The
``expand_lazy_grad`` rewrite differentiates it by lowering ``core_op`` to tensor ops
and taking their pullback, so no XOp ever runs lowering inside its own pullback.
There is one output per differentiable (continuous-dtype) input.
"""

__props__ = ("core_op", "n_cotangents")

def __init__(self, core_op: Op, n_cotangents: int):
self.core_op = core_op
self.n_cotangents = n_cotangents

def make_node(self, *inputs):
forward_inputs = inputs[: -self.n_cotangents]
outputs = [inp.type() for inp in forward_inputs if grad_connected(inp)]
return Apply(self, list(inputs), outputs)


class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType.

Expand Down Expand Up @@ -120,7 +173,7 @@ def make_node(self, x):
def pullback(self, inputs, outs, g_outs):
[x] = inputs
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]
return [type(self)(x.type.dims)(g_out)]

def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
Expand Down
62 changes: 61 additions & 1 deletion pytensor/xtensor/rewriting/basic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
from pytensor.gradient import DisconnectedType, pullback
from pytensor.graph import node_rewriter
from pytensor.graph.basic import clone_get_equiv
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.traversal import ancestors, graph_inputs
from pytensor.tensor.basic import register_infer_shape
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
from pytensor.xtensor.basic import (
LazyGrad,
Rename,
TensorFromXTensor,
XOp,
XTensorFromTensor,
grad_connected,
xtensor_from_tensor,
)
from pytensor.xtensor.random.type import RNGToXRNG, XRNGToRNG
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.rewriting.utils import (
register_lower_lazy_grad,
register_lower_xtensor,
)
from pytensor.xtensor.shape import zeros_like


@register_infer_shape
Expand Down Expand Up @@ -85,3 +96,52 @@ def useless_xrng_to_rng(fgraph, node):
[x] = node.inputs
if x.owner and isinstance(x.owner.op, RNGToXRNG):
return [x.owner.inputs[0]]


@register_lower_lazy_grad
@node_rewriter(tracks=[LazyGrad])
def expand_lazy_grad(fgraph, node):
"""Differentiate an XOp by lowering it to tensor ops and taking their pullback.

Runs before lower_xtensor: the differentiated op (``core_op``) is rebuilt on fresh
stand-ins and lowered to tensor ops in isolation, then differentiated with the
ordinary tensor pullback. Stand-ins (rather than the real inputs) give a repeated
input separate per-slot cotangents, and survive the lowering of the conversion ops
that the real inputs would be folded into.
"""
op = node.op
forward_inputs = node.inputs[: -op.n_cotangents]
cotangents = node.inputs[-op.n_cotangents :]

dummies = [inp.type() if grad_connected(inp) else inp for inp in forward_inputs]
lowered = rewrite_graph(
list(op.core_op.make_node(*dummies).outputs),
include=("lower_lazy_grad", "lower_xtensor"),
)
if any(isinstance(var.owner.op, XOp) for var in ancestors(lowered) if var.owner):
raise NotImplementedError(f"pullback not implemented for {op.core_op}")

memo = {d: inp for d, inp in zip(dummies, forward_inputs) if grad_connected(inp)}
input_grads = pullback(
lowered,
list(memo),
cotangents,
disconnected_inputs="ignore",
return_disconnected="disconnected",
)
# The lowering and pullback above built nodes inside throwaway FunctionGraphs. Re-clone
# the grad into fresh nodes so it imports into the main graph through the normal path,
# grafting the real inputs back in place of the stand-ins. Real variables the grad
# already shares (the node inputs and any value the gradient reuses) are kept as-is.
keep = list(node.inputs) + [
v
for v in graph_inputs(input_grads, blockers=node.inputs)
if v not in memo and v not in set(node.inputs)
]
equiv = clone_get_equiv(keep, input_grads, copy_inputs=False, memo=dict(memo))
# An input the cost doesn't reach through this node contributes a zero (its other
# paths are summed in by the grad engine); a node output can't be DisconnectedType.
return [
zeros_like(inp) if isinstance(grad.type, DisconnectedType) else equiv[grad]
for grad, inp in zip(input_grads, memo.values())
]
27 changes: 27 additions & 0 deletions pytensor/xtensor/rewriting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,26 @@

lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)

# Expanding the lazy gradient Op (LazyGrad) rewrites a whole grad subgraph in one shot,
# which is unsafe to splice into the lower_xtensor equilibrium mid-flight. It runs just
# before it instead, so the expanded grad is lowered by the normal pass like any other.
lower_lazy_grad_db = EquilibriumDB(ignore_newtrees=False)

infer_shape_db.register(
"lower_xtensor",
lower_xtensor_db,
"infer_shape",
)

optdb.register(
"lower_lazy_grad",
lower_lazy_grad_db,
"fast_run",
"fast_compile",
"minimum_compile",
position=0.089, # before lower_xtensor
)

optdb.register(
"lower_xtensor",
lower_xtensor_db,
Expand Down Expand Up @@ -64,6 +78,19 @@ def register(inner_rewriter: RewriteDatabase | NodeRewriter):
return node_rewriter


def register_lower_lazy_grad(node_rewriter: NodeRewriter, **kwargs):
name = kwargs.pop("name", None) or node_rewriter.__name__ # type: ignore
lower_lazy_grad_db.register(
name,
node_rewriter,
"fast_run",
"fast_compile",
"minimum_compile",
**kwargs,
)
return node_rewriter


def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
"""Lower an XTensorVariable to a TensorVariable so that it's dimensions are aligned with "out_dims"."""
inp_dims = {d: i for i, d in enumerate(x.type.dims)}
Expand Down
135 changes: 135 additions & 0 deletions tests/xtensor/test_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import pytest


pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")

import numpy as np

import pytensor
import pytensor.tensor as pt
import pytensor.xtensor as px
from pytensor.graph import rewrite_graph
from pytensor.xtensor.type import as_xtensor
from tests.unittest_tools import verify_grad


def grad_through_lowering(cost, wrt):
"""Reference: lower the xtensor graph to tensor ops, then take the gradient."""
cost = rewrite_graph(cost, include=("lower_xtensor",), clone=True)
return pt.grad(cost, wrt)


def _x():
xt = pt.tensor("x", shape=(3, 4))
return xt, as_xtensor(xt, dims=("a", "b"))


def _y():
yt = pt.tensor("y", shape=(4, 2))
return yt, as_xtensor(yt, dims=("b", "c"))


def build_cases():
xt, x = _x()
yt, y = _y()
return [
("reduce_sum", (px.math.exp(x).sum("a") * 1.5).sum(), [xt]),
("reduce_mean_std", (x.mean("a") + x.std("a")).sum(), [xt]),
("reduce_max", (x.max("a") * 1.5).sum(), [xt]),
("reduce_min", (x.min("a") * 1.5).sum(), [xt]),
("cumsum", px.math.exp(x).cumsum("a").sum(), [xt]),
("elemwise", (px.math.tanh(x) * px.math.sin(x)).sum(), [xt]),
("transpose", (x.transpose("b", "a") ** 2).sum(), [xt]),
("concat", px.concat([x, x + 1.0], dim="a").sum(), [xt]),
("stack", px.math.exp(x).stack({"z": ("a", "b")}).sum(), [xt]),
("rename", (x.rename({"a": "a2"}) ** 2).sum(), [xt]),
# Swapping names exercises Rename as a positional relabel (not a permutation).
("rename_swap", (x.rename({"a": "b", "b": "a"}).sum("a") ** 2).sum(), [xt]),
("dot", (px.dot(x, y, dim="b") ** 2).sum(), [xt, yt]),
]


# Differentiating mean/std lowers to several Shape(x) views that duplicate the
# forward's; merging those duplicates while the destroy handler is attached leaves its
# client bookkeeping inconsistent, which `on_opt_error=raise` (used in tests) turns
# fatal. The gradient itself is correct, so the failure is in graph optimization only.
_XFAIL_DESTROY_HANDLER = {"reduce_mean_std"}


@pytest.mark.parametrize(
"loss, wrt",
[
pytest.param(
loss,
wrt,
id=name,
marks=pytest.mark.xfail(
reason="merging duplicated Shape views upsets the destroy handler",
strict=True,
)
if name in _XFAIL_DESTROY_HANDLER
else (),
)
for name, loss, wrt in build_cases()
],
)
def test_grad_matches_lowering(loss, wrt):
# pt.grad must work directly on the un-lowered xtensor graph and agree with the
# supported "lower first, then grad" path.
rng = np.random.default_rng(7)
test_vals = [rng.normal(size=w.type.shape).astype(w.type.dtype) for w in wrt]
g_direct = pt.grad(loss.values, wrt)
g_ref = grad_through_lowering(loss.values, wrt)
fn = pytensor.function(wrt, [*g_direct, *g_ref])
out = fn(*test_vals)
n = len(wrt)
for direct, ref in zip(out[:n], out[n:]):
np.testing.assert_allclose(direct, ref)


def test_grad_repeated_input():
# A repeated input must accumulate per-slot cotangents (no factor-of-N error).
xt = pt.vector("x", shape=(3,))
x = as_xtensor(xt, dims=("a",))
x_test = np.array([1.0, 2.0, 3.0])
for power, loss in [(2, (x * x).sum()), (3, (x * x * x).sum())]:
g = pytensor.function([xt], pt.grad(loss.values, xt))(x_test)
np.testing.assert_allclose(g, power * x_test ** (power - 1))


def test_grad_second_order():
W = pytensor.shared(np.ones((3, 2)), name="W")
xt = pt.vector("x", shape=(3,))
x = as_xtensor(xt, dims=("a",))
y = px.dot(x, as_xtensor(W, dims=("a", "b")), dim="a")
loss = (y * y).sum()
g2 = pt.grad(pt.grad(loss.values, W).sum(), W)
g2_ref = pt.grad(grad_through_lowering(loss.values, W).sum(), W)
direct, ref = pytensor.function([xt], [g2, g2_ref])(np.arange(3.0))
np.testing.assert_allclose(direct, ref)


def test_grad_through_indexing():
# The index itself is non-differentiable (an integer xtensor) so it gets an
# undefined gradient, but the array input's gradient is still correct: a scatter
# of the cotangent into the indexed positions.
xt = pt.tensor("x", shape=(3, 4))
x = as_xtensor(xt, dims=("a", "b"))
loss = (x.isel(a=1) ** 2).sum()
grad = pt.grad(loss.values, xt)
x_test = np.arange(12.0).reshape(3, 4)
expected = np.zeros((3, 4))
expected[1] = 2 * x_test[1]
np.testing.assert_allclose(pytensor.function([xt], grad)(x_test), expected)


def test_verify_grad():
rng = np.random.default_rng(seed=420)

def dot_loss(x, w):
xx = as_xtensor(x, dims=("a",))
ww = as_xtensor(w, dims=("a", "b"))
return (px.dot(xx, ww, dim="a") ** 2).sum().values

verify_grad(dot_loss, [rng.normal(size=(3,)), rng.normal(size=(3, 2))], rng=rng)
Loading