Skip to content

Implement gradients for xtensor ops via tensor lowering#2269

Draft
cetagostini wants to merge 5 commits into
pymc-devs:mainfrom
cetagostini:xtensor-grad-lowering
Draft

Implement gradients for xtensor ops via tensor lowering#2269
cetagostini wants to merge 5 commits into
pymc-devs:mainfrom
cetagostini:xtensor-grad-lowering

Conversation

@cetagostini

@cetagostini cetagostini commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Problem

pt.grad through un-lowered xtensor ops fails on every backend:

import numpy as np, pytensor, pytensor.tensor as pt
import pytensor.xtensor as px
from pytensor.xtensor.type import as_xtensor
pytensor.config.floatX = "float32"

W = pytensor.shared(np.ones((3, 2), "float32"))
x = px.xtensor("x", dims=("a",), shape=(3,))
y = px.dot(x, as_xtensor(W, dims=("a", "b")), dim="a")
loss = (y * y).sum()

pt.grad(loss.values, W)   # main: NotImplementedError: pullback not implemented for XReduce

XElemwise, XReduce, Dot, Index, the shape ops, … implement neither pullback nor a
legacy grad/L_op. They are designed to be lowered to ordinary tensor ops (the
lower_xtensor rewrite) first, after which the normal tensor gradient rules apply — so the
documented workaround is to rewrite_graph(loss.values, include=("lower_xtensor",)) before
pt.grad. Un-lowered xtensor graphs simply have nothing to differentiate.

Fix

Add a single generic pullback to the base XOp class that performs that lowering per node
and differentiates through it:

  1. build fresh distinct stand-ins for the array inputs (so a repeated input such as
    y * y produces separate per-slot cotangents that the engine accumulates); structural
    inputs (slices, rngs) have no dtype and are kept as is;
  2. lower the single node to its tensor-ops equivalent via the public lower_xtensor rewrite;
  3. take the vector-Jacobian product with the standard pullback, then graph_replace the real
    inputs back in.

This mirrors how OpFromGraph/Scan differentiate their inner graphs, and reuses the
conversion ops' existing pullbacks (TensorFromXTensor/XTensorFromTensor). No per-op gradient
code and no duplication of the tensor gradient rules — every xtensor op becomes differentiable
through its own lowering.

It also fixes Rename.pullback, which misused the rename() keyword API
(rename(g_out, dims=...) tried to rename a dim literally named "dims") and crashed for any
.rename() in the grad path. This was previously unreachable because pt.grad failed earlier
at the un-differentiable XOps.

Example

With the fix, the snippet above returns the gradient directly, matching the lower-then-grad
reference and finite differences (verify_grad):

>>> pytensor.function([x], pt.grad(loss.values, W), mode="MLX")(np.arange(3, dtype="float32"))
array([[ 0.,  0.],
       [ 6.,  6.],
       [12., 12.]], dtype=float32)

tests/xtensor/test_grad.py (new) covers it: a parametrized direct-grad-vs-lowering comparison
over reduce/mean+std/cumsum/elemwise/transpose/concat/stack/rename/swap-rename/dot, plus
repeated-input accumulation, second-order, indexing, and a verify_grad finite-difference
check. The tests are RED on main (NotImplementedError) and GREEN with the fix; full
tests/xtensor/ suite is green.

Notes

  • Ops with non-differentiable inputs (indices/slices) emit the engine's standard
    "implement a connection_pattern" advisory. A generic dtype-based connection_pattern is
    unsound (it would mark integer operands like the 2 in x * 2 as disconnected) and a
    coarse one would turn multi-output Broadcast partial-grads into hard errors, so it is left
    as the benign advisory; the gradients themselves are correct.
  • Three pre-existing, out-of-scope issues surfaced while validating this; I'll add them as
    comments below with self-contained reproducers on current main (happy to open separate
    issues).

AI assistance: this change was developed with Claude Code, including an adversarial
multi-reviewer pass; all gradients were validated against the lowering reference and
verify_grad.

`pt.grad` through un-lowered xtensor ops (XElemwise, XReduce, Dot, ...)
raised `NotImplementedError: pullback not implemented for XReduce`: the
xtensor ops implement neither `pullback` nor a legacy `grad`/`L_op`, as
they are designed to be lowered to tensor ops (the `lower_xtensor`
rewrite) first.

Add a generic `XOp.pullback` that does the lowering per node and
differentiates through it: it lowers the single node to its tensor-ops
equivalent, takes the vector-Jacobian product with the standard
`pullback`, and grafts the real inputs back via `graph_replace`. Repeated
inputs use fresh distinct per-slot stand-ins so the engine accumulates
their cotangents correctly. This mirrors how OpFromGraph/Scan
differentiate their inner graphs and reuses the existing
TensorFromXTensor/XTensorFromTensor pullbacks.

Also fix `Rename.pullback`, which misused the `rename()` keyword API
(`rename(g_out, dims=...)` renamed a dim literally named "dims") and
crashed for any `.rename()` in the grad path -- previously unreachable
because `pt.grad` failed earlier at the un-differentiable XOps.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@cetagostini

cetagostini commented Jun 30, 2026

Copy link
Copy Markdown
Contributor Author

Note

Already fixed on current mainMin gained a pullback in dc503f117 (Add gradient (pullback) to Min Op). xtensor .min()/.max() grad now works through this PR's XOp.pullback; I added reduce_min/reduce_max to tests/xtensor/test_grad.py. My original report below was made against a stale fork maindisregard.

Original (outdated) report

The xtensor min reduction lowers to the tensor Min op, which — unlike Max — implemented no gradient, so min-reductions couldn't be differentiated even after lowering.

@cetagostini

Copy link
Copy Markdown
Contributor Author

Follow-up 2 (pre-existing): XTensorVariable has no zeros_like / ones_like

Surfaced while implementing this PR. There are module-level px.zeros_like / px.ones_like,
but the XTensorVariable instance methods are missing. As a consequence,
pt.grad(cost, wrt, return_disconnected="zero") and gradients of an xtensor-valued cost
raise AttributeError (the new XOp.pullback sidesteps this by using
return_disconnected="disconnected", so this PR is unaffected).

Self-contained reproducer on current main:

import pytensor.xtensor as px

x = px.xtensor("x", dims=("a",), shape=(3,))
print(hasattr(x, "zeros_like"))   # False
print(hasattr(x, "ones_like"))    # False
x.zeros_like()                    # AttributeError

Output on main:

hasattr(XTensorVariable, 'zeros_like'): False
hasattr(XTensorVariable, 'ones_like'): False

Likely fix: add zeros_like/ones_like methods on XTensorVariable delegating to the
module-level helpers. Out of scope here — happy to open a separate issue.

@cetagostini

cetagostini commented Jun 30, 2026

Copy link
Copy Markdown
Contributor Author

Note

Already fixed on current main — the MLX IncSubtensor dispatch now coerces slice bounds to Python ints ("MLX slices reject array-typed bounds"). Verified the sliced-index gradient now runs on MLX ([0, 1, 1, 0, 0]). My original report below was made against a stale fork maindisregard.

Original (outdated) report

The gradient of a sliced index produces an IncSubtensor with a slice, which the MLX backend could not compile (ValueError: Slice indices must be integers or None.).

@cetagostini cetagostini requested a review from ricardoV94 June 30, 2026 07:45
@cetagostini cetagostini self-assigned this Jun 30, 2026
cetagostini and others added 3 commits June 30, 2026 10:49
Per review: no function-local imports. `pytensor.gradient`,
`graph.replace`, `graph.rewriting.utils`, and `graph.traversal` do not
import xtensor, so there is no circular-import risk (the latter two are
already imported at module level in xtensor/vectorization.py).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
With `Min` now carrying a pullback on main (dc503f1), grad through the
xtensor min/max reductions works via the generic XOp.pullback; add them to
the direct-vs-lowering comparison.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@ricardoV94

Copy link
Copy Markdown
Member

The Op shouldn't be calling lowering on the pullback. I'd like to explore a lazy grad Op. The point is I want xtensor Ops to have as little logic as possible. There's a PR opened where we explored that a bit. This approach is a no-go

Address review: XOp.pullback no longer lowers. It wraps the inputs and output
cotangents in a thin LazyGrad XOp; the expand_lazy_grad rewrite (a pass that
runs just before lower_xtensor) differentiates it by lowering core_op to tensor
ops and taking their pullback, so no XOp runs lowering inside its own pullback.
Integer xtensor inputs (e.g. indices) get an undefined gradient instead of
disconnected, which drops the spurious connection_pattern warning.

reduce_mean_std is xfailed: differentiating mean/std produces duplicated Shape
views whose merge upsets the destroy handler under on_opt_error=raise; the
gradient values themselves are correct.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@cetagostini

Copy link
Copy Markdown
Contributor Author

Hey @ricardoV94 — I ran your "lazy grad Op" suggestion through my bot and got something working, dropping it here 👇

Reworked so XOp.pullback doesn't touch lowering anymore. It just wraps the inputs + output cotangents into a thin LazyGrad XOp and returns — that's the whole Op method. The real work happens in a rewrite (expand_lazy_grad, its own pass right before lower_xtensor): it lowers the wrapped op to tensor ops and takes their pullback. So the Op stays dumb like you wanted, and nothing calls lowering from inside a pullback.

Status

  • ✅ original repro works, grads numerically correct
  • ✅ 15/16 grad cases + the rest of the xtensor tests green
  • integer xtensor inputs (indices) now get grad_undefined instead of disconnected → drops the spurious connection_pattern warning (and arguably more correct, since the output does depend on the index value)

The one holdout (xfailed for now): reduce_mean_std. Bot traced it pretty far. Differentiating mean/std lowers into several TensorFromXTensor(x) view nodes that duplicate the forward's. When MergeOptimizer merges those while the destroy handler is attached, a forward Shape(x) gets repointed to the grad's copy of x without being re-registered as a client, so a later local_subtensor_shape_constant prune hits del clients[x][Shape]KeyError. Important bit: the grad values are correct — it only blows up under on_opt_error=raise (i.e. in tests). Couldn't reproduce it with plain duplicate-Shape/duplicate-view graphs, so it looks specific to the multi-view pattern the per-node lowering produces.

Two ways I can see to close it:

  1. Avoid the dup views — have the lazy grad reuse the forward's already-lowered conversions instead of re-lowering its own copy (basically "lower the whole graph, then grad", like the supported path). Cleaner but more surgery — the naive version disconnects because lowering folds tensor_from_xtensor(as_xtensor(xt)) → xt.
  2. Treat it as a destroy-handler bookkeeping bug when merging duplicate view vars and fix it upstream (I have the trace).

Does the lazy-grad shape look right to you before I push further? And any preference between 1 and 2? 🙏

(explored + written with Claude Code)

@ricardoV94

Copy link
Copy Markdown
Member

When I say lazygrad (better called pullback or push forward) I don't mean a per Op one btw

@ricardoV94

Copy link
Copy Markdown
Member

Also please don't reply to me with llm generated messages. Or I'll not reply to them

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants