
Deduplicate Max/Min reduction gradient and add Min forward-mode #2270

Open
cetagostini wants to merge 1 commit into pymc-devs:main from cetagostini:fix-min-careduce-pullback

cetagostini commented Jun 30, 2026

Contributor

Description

This builds on the recently-merged #2253 / dc503f117 ("Add gradient (pullback) to Min Op"), which fixed the missing Min gradient by copy-pasting Max.pullback into Min. This PR:

  1. Deduplicates that gradient logic into a shared base class MaxAndMinCAReduce(NonZeroDimsCAReduce), so Max and Min cannot drift apart. The reduction subgradient is identical for both — the cotangent routes to the position(s) where the input equals the reduced output — so a single source of truth is the natural design. Max/Min become thin subclasses binding only their scalar op.

  2. Adds the forward-mode (pushforward / R_op) that Min still lacks, and generalizes Max's. The old Max.pushforward was matrix-only: it raised NotImplementedError for >2D or multi-axis inputs. Because of its Argmax-based single-winner logic, it was also not the adjoint of its own pullback at ties. The new form

    out_dot = (eq(out_pad, x) * x_dot).sum(axis=axis)

    is the exact transpose (adjoint) of the pullback, supports arbitrary ndim/axis, and gives Min forward-mode for free. The adjoint identity ⟨J·ẋ, ḡ⟩ = ⟨ẋ, Jᵀ·ḡ⟩ now holds on ties (it previously did not).
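The two rules above can be illustrated outside PyTensor. The following is a minimal NumPy sketch, not the PR's implementation: the helper names `reduce_pullback`, `reduce_pushforward`, and `argmax_pushforward` are hypothetical, and `argmax_pushforward` is only an assumed stand-in for single-winner behavior like that described for the old `Max.pushforward`. It checks the adjoint identity on a tied input.

```python
import numpy as np


def reduce_pullback(x, g_out, axis, reducer=np.max):
    """Reverse mode: send the cotangent to every position equal to the reduced value."""
    out_pad = reducer(x, axis=axis, keepdims=True)
    g_pad = np.expand_dims(g_out, axis)  # re-insert the reduced axes
    return (x == out_pad) * g_pad


def reduce_pushforward(x, x_dot, axis, reducer=np.max):
    """Forward mode: sum the tangents at every position equal to the reduced value."""
    out_pad = reducer(x, axis=axis, keepdims=True)
    return ((x == out_pad) * x_dot).sum(axis=axis)


def argmax_pushforward(x, x_dot, axis):
    """Hypothetical single-winner rule: take the tangent at the first argmax only."""
    idx = np.expand_dims(np.argmax(x, axis=axis), axis)
    return np.take_along_axis(x_dot, idx, axis=axis).squeeze(axis)


# Both rows contain a tie for the maximum.
x = np.array([[1.0, 3.0, 3.0], [2.0, 2.0, 0.0]])
x_dot = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
g_out = np.array([10.0, 20.0])

out_dot = reduce_pushforward(x, x_dot, axis=1)  # J @ x_dot
g_x = reduce_pullback(x, g_out, axis=1)         # J.T @ g_out

lhs = float(out_dot @ g_out)      # <J x_dot, g_out>
rhs = float((x_dot * g_x).sum())  # <x_dot, J^T g_out>

# The single-winner rule ignores the second tied entry in each row.
old_dot = argmax_pushforward(x, x_dot, axis=1)
lhs_old = float(old_dot @ g_out)

# Passing np.min as the reducer reuses the same mask logic for Min.
min_dot = reduce_pushforward(x, x_dot, axis=1, reducer=np.min)
```

On this input `lhs` and `rhs` agree, while `lhs_old` does not match `rhs`, which is the tie inconsistency the description refers to. The mask-and-sum form also accepts a tuple for `axis`, so the same helpers cover multi-axis reductions.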

Tests

  • TestPushforwardPullback.test_min — forward/reverse-mode parity with test_max, exercising the bare Min op (which pt.min never instantiates, since it lowers to -max(-x)).
  • TestPushforwardPullback.test_max_min_pushforward_on_ties — deterministic tied-input test pinning the tie-summing convention and the adjoint identity (random data is tie-free almost surely, so it would not guard this).
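Two claims in the test notes above can be checked numerically: that `min(x)` and `-max(-x)` select the same positions, and that continuous random inputs do not produce ties. This is a NumPy illustration only; it does not exercise the PyTensor `Min` op or the PR's tests, and the variable names are chosen for the example.

```python
import numpy as np

# Deterministic input with a tie for the minimum in the second row.
x = np.array([[1.0, 3.0, 3.0], [2.0, 0.0, 0.0]])

min_direct = x.min(axis=1)
min_lowered = -((-x).max(axis=1))  # the -max(-x) lowering mentioned for pt.min

# Positions that receive gradient under each formulation.
mask_min = x == x.min(axis=1, keepdims=True)
mask_lowered = (-x) == (-x).max(axis=1, keepdims=True)

# Continuous random draws: each row has exactly one maximum, so a test
# built on such data never reaches the tie-handling branch.
rng = np.random.default_rng(0)
r = rng.normal(size=(4, 5))
winners_per_row = (r == r.max(axis=1, keepdims=True)).sum(axis=1)
```

The masks agree, including on the tied row, which is why a tie convention has to be pinned with hand-built inputs rather than random ones.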

The reverse-mode regression is already covered by TestMinMax::test_grad_Min (added in #2253), which now runs through the shared base class. Existing Max/Min/CAReduce/uncanonicalize suites pass; gradients verified on the C, Numba, JAX, and MLX backends. Op identity (equality, hashing, pickling, __props__, __str__) is preserved.

Checklist

  • Explain motivation and context (above)
  • New tests added (forward-mode + tie adjoint-consistency)
  • Pre-commit / ruff clean

🤖 Generated with Claude Code

The `Min` Op pullback was recently added (dc503f1) by copy-pasting
`Max.pullback`. This deduplicates that logic into a shared
`MaxAndMinCAReduce` base class so `Max` and `Min` cannot drift apart, and
fills the remaining forward-mode (`pushforward`) gap.

- Move the shared `clone` and `pullback` onto a `MaxAndMinCAReduce` base;
  `Max`/`Min` become thin subclasses binding only their scalar op. The
  reduction subgradient is identical for both (the cotangent routes to the
  position(s) where the input equals the reduced output), so a single
  source of truth is the natural design.
- Replace `Max`'s matrix-only, single-axis `pushforward` (which raised
  `NotImplementedError` for >2D / multi-axis inputs and was *not* the
  adjoint of its own pullback on ties) with a general
  `(eq(out_pad, x) * x_dot).sum(axis)`. This is the exact transpose of the
  pullback, works for any ndim/axis, and gives `Min` forward-mode for free.

Adds forward/reverse parity tests for the `Min` Op and a deterministic
tie-consistency test asserting the Op `pushforward` equals the pullback
transpose at ties (random data is tie-free almost surely).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
cetagostini force-pushed the fix-min-careduce-pullback branch from 480538e to b5b190d (June 30, 2026 08:46)
cetagostini changed the title from "Implement gradient for tensor Min via shared Max/Min reduction base" to "Deduplicate Max/Min reduction gradient and add Min forward-mode" (Jun 30, 2026)
cetagostini requested a review from ricardoV94 (June 30, 2026 08:51)
cetagostini self-assigned this (Jun 30, 2026)
