Deduplicate Max/Min reduction gradient and add Min forward-mode#2270
Open
cetagostini wants to merge 1 commit into
Open
Deduplicate Max/Min reduction gradient and add Min forward-mode#2270cetagostini wants to merge 1 commit into
cetagostini wants to merge 1 commit into
Conversation
The `Min` Op pullback was recently added (dc503f1) by copy-pasting `Max.pullback`. This deduplicates that logic into a shared `MaxAndMinCAReduce` base class so `Max` and `Min` cannot drift apart, and fills the remaining forward-mode (`pushforward`) gap. - Move the shared `clone` and `pullback` onto a `MaxAndMinCAReduce` base; `Max`/`Min` become thin subclasses binding only their scalar op. The reduction subgradient is identical for both (the cotangent routes to the position(s) where the input equals the reduced output), so a single source of truth is the natural design. - Replace `Max`'s matrix-only, single-axis `pushforward` (which raised `NotImplementedError` for >2D / multi-axis inputs and was *not* the adjoint of its own pullback on ties) with a general `(eq(out_pad, x) * x_dot).sum(axis)`. This is the exact transpose of the pullback, works for any ndim/axis, and gives `Min` forward-mode for free. Adds forward/reverse parity tests for the `Min` Op and a deterministic tie-consistency test asserting the Op `pushforward` equals the pullback transpose at ties (random data is tie-free almost surely). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
480538e to
b5b190d
Compare
Min via shared Max/Min reduction base
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This builds on the recently-merged #2253 /
dc503f117("Add gradient (pullback) toMinOp"), which fixed the missingMingradient by copy-pastingMax.pullbackintoMin. This PR:Deduplicates that gradient logic into a shared base class
MaxAndMinCAReduce(NonZeroDimsCAReduce), soMaxandMincannot drift apart. The reduction subgradient is identical for both — the cotangent routes to the position(s) where the input equals the reduced output — so a single source of truth is the natural design.Max/Minbecome thin subclasses binding only their scalar op.Adds the forward-mode (
pushforward/ R_op) thatMinstill lacks, and generalizesMax's. The oldMax.pushforwardwas matrix-only (raisedNotImplementedErrorfor >2D or multi-axis inputs) and — via itsArgmax-based single-winner logic — was not the adjoint of its own pullback at ties. The new formis the exact transpose (adjoint) of the
pullback, supports arbitrary ndim/axis, and givesMinforward-mode for free. The adjoint identity⟨J·ẋ, ḡ⟩ = ⟨ẋ, Jᵀ·ḡ⟩now holds on ties (it previously did not).Tests
TestPushforwardPullback.test_min— forward/reverse-mode parity withtest_max, exercising the bareMinop (whichpt.minnever instantiates, since it lowers to-max(-x)).TestPushforwardPullback.test_max_min_pushforward_on_ties— deterministic tied-input test pinning the tie-summing convention and the adjoint identity (random data is tie-free almost surely, so it would not guard this).The reverse-mode regression is already covered by
TestMinMax::test_grad_Min(added in #2253), which now runs through the shared base class. ExistingMax/Min/CAReduce/uncanonicalizesuites pass; gradients verified on the C, Numba, JAX, and MLX backends. Op identity (equality, hashing, pickling,__props__,__str__) is preserved.Checklist
🤖 Generated with Claude Code