Implement gradients for xtensor ops via tensor lowering#2269
Implement gradients for xtensor ops via tensor lowering#2269cetagostini wants to merge 5 commits into
Conversation
`pt.grad` through un-lowered xtensor ops (XElemwise, XReduce, Dot, ...) raised `NotImplementedError: pullback not implemented for XReduce`: the xtensor ops implement neither `pullback` nor a legacy `grad`/`L_op`, as they are designed to be lowered to tensor ops (the `lower_xtensor` rewrite) first. Add a generic `XOp.pullback` that does the lowering per node and differentiates through it: it lowers the single node to its tensor-ops equivalent, takes the vector-Jacobian product with the standard `pullback`, and grafts the real inputs back via `graph_replace`. Repeated inputs use fresh distinct per-slot stand-ins so the engine accumulates their cotangents correctly. This mirrors how OpFromGraph/Scan differentiate their inner graphs and reuses the existing TensorFromXTensor/XTensorFromTensor pullbacks. Also fix `Rename.pullback`, which misused the `rename()` keyword API (`rename(g_out, dims=...)` renamed a dim literally named "dims") and crashed for any `.rename()` in the grad path -- previously unreachable because `pt.grad` failed earlier at the un-differentiable XOps. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Note Already fixed on current Original (outdated) reportThe xtensor |
Follow-up 2 (pre-existing):
|
|
Note Already fixed on current Original (outdated) reportThe gradient of a sliced index produces an |
Per review: no function-local imports. `pytensor.gradient`, `graph.replace`, `graph.rewriting.utils`, and `graph.traversal` do not import xtensor, so there is no circular-import risk (the latter two are already imported at module level in xtensor/vectorization.py). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
With `Min` now carrying a pullback on main (dc503f1), grad through the xtensor min/max reductions works via the generic XOp.pullback; add them to the direct-vs-lowering comparison. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
The Op shouldn't be calling lowering on the pullback. I'd like to explore a lazy grad Op. The point is I want xtensor Ops to have as little logic as possible. There's a PR opened where we explored that a bit. This approach is a no-go |
Address review: XOp.pullback no longer lowers. It wraps the inputs and output cotangents in a thin LazyGrad XOp; the expand_lazy_grad rewrite (a pass that runs just before lower_xtensor) differentiates it by lowering core_op to tensor ops and taking their pullback, so no XOp runs lowering inside its own pullback. Integer xtensor inputs (e.g. indices) get an undefined gradient instead of disconnected, which drops the spurious connection_pattern warning. reduce_mean_std is xfailed: differentiating mean/std produces duplicated Shape views whose merge upsets the destroy handler under on_opt_error=raise; the gradient values themselves are correct. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Hey @ricardoV94 — I ran your "lazy grad Op" suggestion through my bot and got something working, dropping it here 👇 Reworked so Status
The one holdout (xfailed for now): Two ways I can see to close it:
Does the lazy-grad shape look right to you before I push further? And any preference between 1 and 2? 🙏 (explored + written with Claude Code) |
|
When I say lazygrad (better called pullback or push forward) I don't mean a per Op one btw |
|
Also please don't reply to me with llm generated messages. Or I'll not reply to them |
Problem
pt.gradthrough un-lowered xtensor ops fails on every backend:XElemwise,XReduce,Dot,Index, the shape ops, … implement neitherpullbacknor alegacy
grad/L_op. They are designed to be lowered to ordinary tensor ops (thelower_xtensorrewrite) first, after which the normal tensor gradient rules apply — so thedocumented workaround is to
rewrite_graph(loss.values, include=("lower_xtensor",))beforept.grad. Un-lowered xtensor graphs simply have nothing to differentiate.Fix
Add a single generic
pullbackto the baseXOpclass that performs that lowering per nodeand differentiates through it:
y * yproduces separate per-slot cotangents that the engine accumulates); structuralinputs (slices, rngs) have no
dtypeand are kept as is;lower_xtensorrewrite;pullback, thengraph_replacethe realinputs back in.
This mirrors how
OpFromGraph/Scandifferentiate their inner graphs, and reuses theconversion ops' existing pullbacks (
TensorFromXTensor/XTensorFromTensor). No per-op gradientcode and no duplication of the tensor gradient rules — every xtensor op becomes differentiable
through its own lowering.
It also fixes
Rename.pullback, which misused therename()keyword API(
rename(g_out, dims=...)tried to rename a dim literally named"dims") and crashed for any.rename()in the grad path. This was previously unreachable becausept.gradfailed earlierat the un-differentiable XOps.
Example
With the fix, the snippet above returns the gradient directly, matching the lower-then-grad
reference and finite differences (
verify_grad):tests/xtensor/test_grad.py(new) covers it: a parametrized direct-grad-vs-lowering comparisonover reduce/mean+std/cumsum/elemwise/transpose/concat/stack/rename/swap-rename/dot, plus
repeated-input accumulation, second-order, indexing, and a
verify_gradfinite-differencecheck. The tests are RED on
main(NotImplementedError) and GREEN with the fix; fulltests/xtensor/suite is green.Notes
"implement a
connection_pattern" advisory. A generic dtype-basedconnection_patternisunsound (it would mark integer operands like the
2inx * 2as disconnected) and acoarse one would turn multi-output
Broadcastpartial-grads into hard errors, so it is leftas the benign advisory; the gradients themselves are correct.
comments below with self-contained reproducers on current
main(happy to open separateissues).
AI assistance: this change was developed with Claude Code, including an adversarial
multi-reviewer pass; all gradients were validated against the lowering reference and
verify_grad.