Implement jvp for cumulative logsumexp #3711

Open
obchain wants to merge 1 commit into ml-explore:main from obchain:fix/scan-jvp-logcumsumexp

obchain commented Jun 17, 2026

Contributor

Proposed changes

Fixes #3710.

Scan::jvp only implemented the Sum reduction and threw for everything else, so forward-mode differentiation through mx.logcumsumexp raised "JVP is not implemented for cumulative prod/min/max". The vjp was already implemented, so only forward mode was affected.

The jvp of logcumsumexp is the running softmax-weighted sum of the tangents, where the softmax for output k is taken over the prefix x_i, i <= k:

d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x)_i * t_i
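
A minimal NumPy sketch of this formula, assuming a 1-D real input and an inclusive, forward scan. The helper names `logcumsumexp_ref` and `jvp_ref` are illustrative only and are not part of MLX. The weight on tangent i for output k is exp(x_i - y_k), which is the softmax over the prefix ending at k.

```python
import numpy as np

def logcumsumexp_ref(x):
    # Inclusive running log-sum-exp: y_k = log(sum_{i<=k} exp(x_i)).
    return np.logaddexp.accumulate(x)

def jvp_ref(x, t):
    # Tangent of y_k: prefix softmax weights times the prefix of tangents.
    y = logcumsumexp_ref(x)
    out = np.empty_like(y)
    for k in range(len(x)):
        w = np.exp(x[: k + 1] - y[k])  # softmax over x_0..x_k, sums to 1
        out[k] = np.sum(w * t[: k + 1])
    return out

x = np.array([1.0, 2.0, 3.0])
t = np.array([1.0, -2.0, 0.5])

# Central finite difference along t as an independent check of the formula.
eps = 1e-6
fd = (logcumsumexp_ref(x + eps * t) - logcumsumexp_ref(x - eps * t)) / (2 * eps)
assert np.allclose(jvp_ref(x, t), fd, atol=1e-6)

# With an all-ones tangent each prefix softmax sums to 1, so the jvp is all ones.
assert np.allclose(jvp_ref(x, np.ones(3)), np.ones(3))
```

The all-ones case is the same input and tangent as the `mx.jvp` call in the "Before" example, so under this formula the tangent output there would be all ones.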

This is computed in log space for numerical stability by splitting the tangent into its positive and negative parts, mirroring the structure of the existing LogAddExp vjp. Exclusive scans leave the first element with no inputs (output -inf, locally constant), so its tangent is set to zero — this also avoids an inf - inf in the expression.
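
The log-space computation described above can be sketched in NumPy as follows. This is an illustration under assumptions (1-D real input, forward scan only), not the C++ code from the PR; the function names are hypothetical, and `finfo(...).min` stands in for "log of zero" in the masked-out positions.

```python
import numpy as np

def logcumsumexp_np(x, exclusive=False):
    # Running log-sum-exp; an exclusive scan shifts by one and starts at -inf.
    y = np.logaddexp.accumulate(x)
    if exclusive:
        y = np.concatenate(([-np.inf], y[:-1]))
    return y

def jvp_log_space(x, t, exclusive=False):
    y = logcumsumexp_np(x, exclusive)
    log_min = np.finfo(x.dtype).min          # placeholder for log(0)
    with np.errstate(divide="ignore"):       # log(0) for zero tangents
        log_abs_t = np.log(np.abs(t))
    log_t_pos = np.where(t > 0, log_abs_t, log_min)
    log_t_neg = np.where(t < 0, log_abs_t, log_min)

    def masked_scan(log_t):
        # exp(logcumsumexp(x + log t) - y) = running sum of softmax_i * |t_i|.
        s = logcumsumexp_np(x + log_t, exclusive)
        with np.errstate(invalid="ignore"):  # -inf - (-inf) in the empty window
            out = np.exp(s - y)
        # An empty window has output -inf and is locally constant: tangent 0.
        return np.where(np.isneginf(y), 0.0, out)

    return masked_scan(log_t_pos) - masked_scan(log_t_neg)

x = np.array([1.0, 2.0, 3.0])
print(jvp_log_space(x, np.ones(3)))                  # inclusive scan
print(jvp_log_space(x, np.ones(3), exclusive=True))  # first tangent is zero
```

Subtracting the two scans recovers the signed sum because each scan only ever sees non-negative weights, which is what makes taking the log of the tangent legal.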

cumprod / cummax / cummin jvps are unchanged (still not implemented) and the error message stays accurate.

Before:

>>> mx.jvp(lambda z: mx.logcumsumexp(z), [mx.array([1.,2.,3.])], [mx.ones(3)])
RuntimeError: JVP is not implemented for cumulative prod/min/max

Added a test in test_autograd.py that checks the jvp against an explicit softmax-weighted reference and verifies the jvp/vjp adjoint identity for every combination of the reverse / inclusive flags across axes.
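
For readers unfamiliar with the adjoint identity, here is a hedged NumPy sketch of what such a check looks like for the four reverse / inclusive combinations on a 1-D input. It is not the test from the PR: `scan_jacobian` is a hypothetical helper that builds the explicit Jacobian, so the identity holds by construction here. In a real test the jvp and vjp would come from separate code paths, which is what makes the comparison meaningful.

```python
import itertools
import numpy as np

def scan_jacobian(x, reverse, inclusive):
    # J[k, i] = d y_k / d x_i: softmax over the window that feeds output k.
    n = len(x)
    J = np.zeros((n, n))
    for k in range(n):
        if reverse:
            idx = np.arange(k if inclusive else k + 1, n)
        else:
            idx = np.arange(0, k + 1 if inclusive else k)
        if idx.size:  # an exclusive scan has one empty window (zero row)
            e = np.exp(x[idx] - x[idx].max())
            J[k, idx] = e / e.sum()
    return J

rng = np.random.default_rng(0)
x, t, w = rng.normal(size=(3, 5))  # input, tangent, cotangent

for reverse, inclusive in itertools.product([False, True], repeat=2):
    J = scan_jacobian(x, reverse, inclusive)
    jvp = J @ t    # forward mode: push the tangent through
    vjp = J.T @ w  # reverse mode: pull the cotangent back
    # Adjoint identity: <w, J t> == <J^T w, t>.
    assert np.isclose(w @ jvp, vjp @ t)
```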

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@aeiwz aeiwz left a comment

Finding

[P2] Preserve complex tangent values: mlx/mlx/primitives.cpp, lines 4308 to 4316 in cba7502

    auto log_min = array(finfo(t.dtype()).min, t.dtype());
    auto log_abs_t = log(abs(t, stream()), stream());
    auto log_t_positive =
        where(greater(t, zero, stream()), log_abs_t, log_min, stream());
    auto log_t_negative =
        where(less(t, zero, stream()), log_abs_t, log_min, stream());
    auto masked_scan = [&](const array& log_t) {
      return exp(

The positive/negative split uses abs(t) and comparisons, which discards the phase of complex tangents. Since logcumsumexp supports complex arrays, a JVP with a complex tangent such as [1+1j, 2-1j] returns real magnitudes instead of the expected complex derivative. The implementation should either support complex tangents directly or explicitly reject them, and a complex JVP test should be added.

Numerically verified expected derivative:

[1+1j, 1.73106-0.462117j]

The current expression produces approximately:

[1.41421+0j, 2.01504+0j]
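
The two vectors above can be reproduced with a short NumPy sketch. The review does not state the input, so `x = [1, 2]` is an assumption, chosen because it matches both quoted results. Likewise, modelling the current code path as "softmax-weighted sum of |t|" is an interpretation of the abs-based split, not a trace of the MLX implementation.

```python
import numpy as np

x = np.array([1.0, 2.0])        # assumed input, not given in the review
t = np.array([1 + 1j, 2 - 1j])  # complex tangent from the review

y = np.logaddexp.accumulate(x)  # inclusive logcumsumexp of x

def weighted(v):
    # sum_{i<=k} exp(x_i - y_k) * v_i for each k
    return np.array(
        [np.sum(np.exp(x[: k + 1] - y[k]) * v[: k + 1]) for k in range(len(x))]
    )

expected = weighted(t)            # keeps the phase of each tangent
magnitudes = weighted(np.abs(t))  # what an abs-based split would accumulate

assert np.allclose(expected, [1 + 1j, 1.73106 - 0.462117j], atol=1e-5)
assert np.allclose(magnitudes, [1.41421, 2.01504], atol=1e-5)
```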

No other actionable findings.

Scan::jvp only handled the Sum reduction and threw for everything else,
so forward-mode differentiation through mx.logcumsumexp raised. The jvp
is the running softmax-weighted sum of the tangents,

    d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x)_i * t_i,

computed in log space by splitting the tangent into its positive and
negative parts, mirroring the existing vjp. Exclusive scans leave the
first element with no inputs (output -inf, locally constant), so its
tangent is set to zero, which also avoids an inf - inf there.

obchain force-pushed the fix/scan-jvp-logcumsumexp branch from cba7502 to 715e141 on June 24, 2026 06:26

obchain commented Jun 24, 2026

Contributor Author

Good observation. The positive/negative split is taken directly from the existing LogAddExp vjp in the same function, which makes the same real-tangent assumption. log(t) is undefined for a negative real t, which is why the code uses log(abs(t)) plus a sign split.

I checked, and the current vjp is also not phase-correct for complex tangents (the holomorphic adjoint <conj(w), Jv> == <conj(vjp), t> doesn't hold for complex w/t), so this is a pre-existing limitation on the reverse side rather than something this jvp introduces. Making logcumsumexp fully correct for complex tangents would mean changing both the jvp and the vjp together, so I'd prefer to keep that out of this PR and scope it to match the existing vjp (real tangents), which is the common case. Happy to do the complex jvp+vjp as a follow-up if that's wanted.

Development

Successfully merging this pull request may close these issues.

[BUG] jvp of mx.logcumsumexp is not implemented