Add Mooncake extension; extend AD conformance suite with NamedTuple and cache-reuse coverage #160
yebai wants to merge 12 commits into main
Conversation
AbstractPPL.jl documentation for PR #160 is available at:
Codecov Report: ❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #160      +/-   ##
==========================================
+ Coverage   87.11%   88.20%   +1.09%
==========================================
  Files          14       15       +1
  Lines         784      899     +115
==========================================
+ Hits          683      793     +110
- Misses        101      106       +5
☔ View full report in Codecov by Sentry.
Mooncake AD-backend extension built on the evaluator interface, with the
shared conformance suite extended to cover NamedTuple inputs and empty-input
arity errors. Squashed from prior incremental commits:
- AbstractPPLMooncakeExt: cache reuse, scalar/vector dispatch, NamedTuple
inputs via VectorEvaluator/NamedTupleEvaluator wrappers; integration test
in test/ext/mooncake.
- Evaluators._ad_output_arity: lift the duplicated `Union{Number,
AbstractVector}` output check from both extensions into one helper that
returns `:scalar` / `:vector` for downstream dispatch.
- Empty-input arity tagging (`Val(:scalar)` / `Val(:vector)`) so the
empty-input fast path raises the same "requires a scalar/vector-valued
function" error as the DI path instead of silently succeeding.
- AbstractPPLTestExt: add `Val(:namedtuple)` group (one ValueCase + one
ErrorCase); tighten regex assertions on the existing arity-mismatch cases.
- check_dims threaded through the inner `prepare` call so AD hot paths can
skip per-call shape checks.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
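A minimal sketch of what the lifted helper could look like, assuming it classifies an output value (the actual definition in the PR may dispatch on types instead):

# Hedged sketch; not necessarily the PR's exact signature.
function _ad_output_arity(y)
    y isa Number && return :scalar
    y isa AbstractVector && return :vector
    throw(ArgumentError("requires a scalar- or vector-valued function"))
end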
`_check_ad_input(evaluator, x)` in `Evaluators` replaces the duplicated
`T <: Integer` rejection plus length check that appeared at six AD entry
points (two in the DI extension, four in Mooncake). Compile-time `T`
elision is preserved.
Move `generate_testcases(::Val{:namedtuple})` and
`run_testcases(::Val{:namedtuple})` to sit alongside the `:vector` and
`:edge` definitions so the file reads generate-then-run for all three
groups.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
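A hedged sketch of the shared input check described above; `expected_length` is a hypothetical accessor, not necessarily the PR's field name:

function _check_ad_input(evaluator, x::AbstractVector{T}) where {T}
    # `T <: Integer` is known at compile time, so this branch is elided
    # from the IR for floating-point inputs.
    T <: Integer &&
        throw(ArgumentError("AD is not defined for integer-valued inputs"))
    length(x) == expected_length(evaluator) ||  # hypothetical accessor
        throw(DimensionMismatch("input length $(length(x)) does not match the prepared evaluator"))
    return nothing
end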
Force-pushed 7232d5b to f4847f1
Force-pushed f4847f1 to 771c2e6
- `:cache_reuse` conformance group in `AbstractPPLTestExt` drives
`value_and_{gradient,jacobian}!!` three times per case against a single
`prepared` evaluator to catch backend cache corruption between calls.
- DI ext sub-environment now also loads `ReverseDiff` and exercises
`AutoReverseDiff(compile=true)` against the conformance suite, covering
the `_prepare_di(::AutoReverseDiff{true}, …)` compiled-tape path.
- Lift the duplicated `value_and_{gradient,jacobian}!!` arity-mismatch
`ArgumentError` strings into shared `Evaluators._throw_*` helpers used by
both the DI and Mooncake extensions.
- `generate_testcases` docstring lists `:namedtuple` and `:cache_reuse`
alongside `:vector` / `:edge` as reserved group keys.
- Trim verbose `check_dims` clarifications in docstrings and
`docs/src/evaluators.md` to one sentence each.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
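Roughly what a `:cache_reuse` case exercises, sketched with the public API shown elsewhere in this thread (the suite's actual ValueCase machinery is more general):

using AbstractPPL, ADTypes, Mooncake
p = AbstractPPL.prepare(AutoMooncake(), x -> sum(abs2, x), [1.0, 2.0])
for _ in 1:3
    val, grad = AbstractPPL.value_and_gradient!!(p, [1.0, 2.0])
    @assert val == 5.0 && grad == [2.0, 4.0]  # a corrupted cache would drift between calls
end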
Force-pushed 771c2e6 to 1f8b157
@shravanngoswamii, this should be the final PR of the sequence, which adds a native Mooncake backend. EDIT: There are some minor changes to the DI and Mooncake extensions motivated by DynamicPPL's needs.
Stale manifests cause subtle resolution and loading issues; document the expected `Pkg.update()` step alongside the existing test commands. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
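For example, the documented local workflow would look something like:

using Pkg
Pkg.update()             # refresh stale manifests first
Pkg.test("AbstractPPL")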
Two regressions were visible on tiny-model gradients that went through the new
AbstractPPL evaluator interface:
- `_check_ad_input` always ran on `value_and_{gradient,jacobian}!!`
entry, even when the evaluator was prepared with `check_dims=false`.
Now dispatch-gated on `VectorEvaluator{CheckInput}`: the `{false}`
overload is a no-op, so the `DimensionMismatch` and integer-rejection
paths are elided from the LLVM IR of the AD hot path.
- `DICache` stored `use_context::Bool` as a runtime field, leaving a
branch in the compiled call selecting the context vs no-context DI
form. `UseContext` is now a type parameter and the branch is resolved
by dispatch via `_di_value_and_{gradient,jacobian}` helpers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
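A hedged sketch of the dispatch gating, assuming the boolean `CheckInput` parameter sits where shown (the real method bodies and parameter position may differ):

_maybe_check_input(e::VectorEvaluator{true}, x) = _check_ad_input(e, x)
_maybe_check_input(e::VectorEvaluator{false}, x) = nothing  # no-op, elided from the AD hot path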
`Mooncake.value_and_gradient!!(cache, evaluator, x)` reset the evaluator's tangent buffer on every call, even though AbstractPPL discards `∂f` and only surfaces `∂x`. For evaluators that wrap a model with large fields (e.g. a 128-tuple of `Float64`), the zeroing was the dominant per-call overhead at tiny model sizes.

Pass `args_to_zero=(false, true)` to the reverse-mode `Mooncake.Cache` path to skip the `∂f` reset while still zeroing the `∂x` buffer. The forward-mode `Mooncake.ForwardCache` doesn't accept the kwarg, so the branch is `isa`-dispatched on the concrete cache type and constant-folds at compile time.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
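A sketch of the `isa`-gated call site; the `args_to_zero` kwarg is taken from the commit message above rather than independently verified against Mooncake's interface:

import Mooncake

function _mc_value_and_gradient(cache, f, x)
    if cache isa Mooncake.Cache
        # reverse mode: skip re-zeroing ∂f (it is discarded), still zero ∂x
        return Mooncake.value_and_gradient!!(cache, f, x; args_to_zero=(false, true))
    else
        # forward-mode ForwardCache does not accept the kwarg
        return Mooncake.value_and_gradient!!(cache, f, x)
    end
end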
Concise inline notes on:
- `VectorEvaluator{true|false}` callable bodies (shared `T <: Integer`
compile-time elision, and the `{false}` skip of `_check_vector_length`).
- Mooncake ext empty-input and arity-mismatch methods (compile-time dispatch
via `MooncakeCache{…,Nothing}` and `MooncakeCache{:scalar|:vector}` to
avoid runtime branches).
- `args_to_zero=(false, true)` at both Mooncake gradient call sites
(skipping the evaluator's tangent re-zeroing per call — `∂f` is discarded).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mooncake was deriving a nested `Tangent{NamedTuple{f::Tangent{...}}}` for
every `VectorEvaluator`/`NamedTupleEvaluator` it received, then walking
that structure on every backward pass. The evaluators are AbstractPPL's
own wrapper types and never appear as a downstream gradient target — the
public API only returns `(value, ∂x)`.
Register `Mooncake.tangent_type(::Type{<:VectorEvaluator}) = NoTangent`
(and the same for `NamedTupleEvaluator`) so the cache carries no tangent
for the user's problem fields. The `args_to_zero=(false, true)` mitigation
and the `_ConstantEvaluator` wrapper from the prior pass are both no
longer needed; the call sites pass `p.evaluator` directly.
Verified on the MWE setup: `Mooncake.Tangent{` count in the prepared cache
type is 0; value and gradient match a direct `logdensity_at(x, state, …)`
call.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
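The registration as stated above, with module paths inferred from the stacktraces elsewhere in this thread:

import Mooncake
using AbstractPPL: Evaluators

Mooncake.tangent_type(::Type{<:Evaluators.VectorEvaluator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:Evaluators.NamedTupleEvaluator}) = Mooncake.NoTangent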
Callers who know an equivalent raw `f(x, contexts...) ≡ problem(x)` can pass it via `prepare(AutoMooncake(), problem, x; raw_gradient_target=(f, contexts))`. Mooncake then compiles the tape on the raw call shape with `args_to_zero=(false, true, false, …)` instead of the generic `evaluator(x)` wrapper, sidestepping the fixed overhead seen on tiny scalar-vector problems.

`prepared(x)` still calls `problem(x)`; only the AD entry uses the lowered cache (a new `MooncakeLoweredCache` carries `cache`, `f`, `contexts`, and `args_to_zero`). Scoped strictly to reverse-mode `AutoMooncake` and scalar arity with non-empty input; anything else errors at prepare time. Jacobian on a lowered cache surfaces the existing arity-mismatch error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
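A usage sketch based on the description above; the contexts element is passed as a tuple, matching the `CT<:Tuple` constraint visible later in the review:

using AbstractPPL, ADTypes, Mooncake
raw(x, offset) = -0.5 * (x[1] - offset)^2   # raw call shape: raw(x, contexts...) ≡ problem(x)
problem = x -> raw(x, 0.1)
p = AbstractPPL.prepare(AutoMooncake(), problem, [0.3];
                        raw_gradient_target=(raw, (0.1,)))
AbstractPPL.value_and_gradient!!(p, [0.3])  # AD entry goes through the lowered cache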
…rget on all AD prepare methods
- Collapsed `MooncakeLoweredCache` into `MooncakeCache{A,C,F,CT,AZ}`. The
three new type params default to `Nothing` via the existing constructor;
the lowered-path constructor populates them. Dispatch on `CT<:Tuple`
(excluding the `Nothing` default) picks the lowered AD entry. No new
type, no runtime branching.
- DI extension's `prepare(::AbstractADType, ...)` now accepts
`raw_gradient_target=nothing` and silently ignores it. Same for the
Mooncake NamedTuple `prepare`. Generic user code that passes the kwarg
to non-Mooncake backends (or to the Mooncake NamedTuple path) no longer
hits a MethodError.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- `args_to_zero` was a derived value (`(false, true, false×length(contexts))`) stored as a struct field plus a 5th type parameter. Moved the construction to the AD entry; the tuple constant-folds for any concrete `contexts` arity. Saves one type parameter and one field.
- Dropped two trailing comments on `raw_gradient_target=nothing` kwargs (the comment didn't explain why; the kwarg name and surrounding context already convey "this is a backend-specific optimization that defaults off").
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
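A hedged reconstruction of the collapsed cache type, based on the constructors visible in the review stacktraces below; field names are assumptions:

struct MooncakeCache{A,C,F,CT}
    cache::C
    f::F
    contexts::CT
    MooncakeCache{A}(cache::C) where {A,C} =
        new{A,C,Nothing,Nothing}(cache, nothing, nothing)
    MooncakeCache{A}(cache::C, f::F, contexts::CT) where {A,C,F,CT<:Tuple} =
        new{A,C,F,CT}(cache, f, contexts)
end

# `CT <: Tuple` (never the `Nothing` default) selects the lowered AD entry:
_is_lowered(::MooncakeCache{<:Any,<:Any,<:Any,<:Tuple}) = true
_is_lowered(::MooncakeCache) = false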
shravanngoswamii left a comment
Other than these comments, PR looks good to me. Happy to approve!
function AbstractPPL.prepare(
    adtype::_MooncakeAD,
    problem,
    x::AbstractVector{<:Real};
I may be missing some Mooncake-specific expectation here, but can we check if this method should accept every AbstractVector?
With a SubArray/view, reverse-mode prepares and runs, but the returned gradient is a Mooncake tangent object instead of a normal vector-like gradient. Forward-mode prepares but then errors at call time, and Jacobian also errors later because Mooncake wants a dense vector input. See:
julia> using AbstractPPL, ADTypes, Mooncake
julia> xs = view([1.0, 2.0, 3.0], :)
3-element view(::Vector{Float64}, :) with eltype Float64:
1.0
2.0
3.0
julia> p_rev = AbstractPPL.prepare(AutoMooncake(), x -> sum(abs2, x), xs);
julia> val, grad = AbstractPPL.value_and_gradient!!(p_rev, xs);
julia> typeof(grad)
Mooncake.Tangent{@NamedTuple{parent::Vector{Float64}, indices::Mooncake.NoTangent, offset1::Mooncake.NoTangent, stride1::Mooncake.NoTangent}}
julia> p_fwd = AbstractPPL.prepare(AutoMooncakeForward(), x -> sum(abs2, x), xs);
julia> AbstractPPL.value_and_gradient!!(p_fwd, xs)
ERROR: MethodError: no method matching setindex!(::Mooncake.Tangent{@NamedTuple{…}}, ::Float64, ::Int64)
The function `setindex!` exists, but no method is defined for this combination of argument types.
Stacktrace:
[1] _fcache_gradient_seed_tangent(x::SubArray{…}, slot::Int64, cursor::Base.RefValue{…}, dict::IdDict{…})
@ Mooncake ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1240
[2] #_fcache_gradient_seed_tangent##0
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1284 [inlined]
[3] ntuple
@ ./ntuple.jl:51 [inlined]
[4] _fcache_gradient_seed_tangent
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1283 [inlined]
[5] _fcache_gradient_seed_tangent
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1206 [inlined]
[6] (::Mooncake.var"#_fcache_gradient_chunked!!##0#_fcache_gradient_chunked!!##1"{Tuple{…}})(lane::Int64)
@ Mooncake ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1792
[7] ntuple
@ ./ntuple.jl:19 [inlined]
[8] _fcache_gradient_chunked!!(cache::Mooncake.ForwardCache{…}, input_primals::Tuple{…})
@ Mooncake ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1791
[9] value_and_gradient!!
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:2032 [inlined]
[10] value_and_gradient!!(p::AbstractPPL.Evaluators.Prepared{…}, x::SubArray{…})
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:147
[11] top-level scope
@ REPL[7]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> p_vec = AbstractPPL.prepare(AutoMooncake(), x -> [x[1] * x[2], x[2] + x[3]], xs);
julia> AbstractPPL.value_and_jacobian!!(p_vec, xs)
ERROR: ArgumentError: value_and_jacobian!! only supports dense vector inputs; got SubArray{Float64, 1, Vector{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}}, true}
Stacktrace:
[1] _validate_jacobian_argument(x::SubArray{Float64, 1, Vector{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}}, true})
@ Mooncake ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:2411
[2] value_and_jacobian!!
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:2522 [inlined]
[3] value_and_jacobian!!(p::AbstractPPL.Evaluators.Prepared{…}, x::SubArray{…})
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:207
[4] top-level scope
@ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.
For comparison, the DI path returns a normal `Vector{Float64}` for the same scalar-gradient case with the view input.
Is this expected from Mooncake for views? If yes, maybe this extension should restrict the Mooncake vector path to dense vectors and throw a clear ArgumentError early. If views are meant to work, then maybe the returned tangent needs to be converted back into the gradient shape expected by AbstractPPL. A small regression test with view([1.0, 2.0, 3.0], :) would help make the intended behavior clear.
@inline function AbstractPPL.value_and_gradient!!(
    p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple
)
    val, (_, grad) = Mooncake.value_and_gradient!!(p.cache, p.evaluator, values)
    return (val, grad)
end
Can we check if AbstractPPL's NamedTupleEvaluator shape check should run once before calling Mooncake here?
For vector inputs, this extension calls _check_ad_input before entering Mooncake, so wrong user inputs get AbstractPPL's normal errors. But for NamedTuple inputs, Mooncake validates its cache first for some mismatches. That means prepared(values) and value_and_gradient!!(prepared, values) can surface different errors for the same bad input.
julia> using AbstractPPL, ADTypes, Mooncake
julia> p = AbstractPPL.prepare(AutoMooncake(), vs -> vs.x^2 + sum(abs2, vs.y), (x=0.0, y=zeros(2)),);
julia> p((y=[1.0, 2.0], x=3.0))
ERROR: ArgumentError: Expected the same NamedTuple structure that was used to prepare this evaluator.
Stacktrace:
[1] _assert_namedtuple_shape(e::AbstractPPL.Evaluators.NamedTupleEvaluator{…}, values::@NamedTuple{…})
@ AbstractPPL.Evaluators ~/Work/vectorly-ai/AbstractPPL.jl/src/evaluators/Evaluators.jl:224
[2] (::AbstractPPL.Evaluators.NamedTupleEvaluator{…})(values::@NamedTuple{…})
@ AbstractPPL.Evaluators ~/Work/vectorly-ai/AbstractPPL.jl/src/evaluators/Evaluators.jl:207
[3] (::AbstractPPL.Evaluators.Prepared{…})(x::@NamedTuple{…})
@ AbstractPPL.Evaluators ~/Work/vectorly-ai/AbstractPPL.jl/src/evaluators/Evaluators.jl:40
[4] top-level scope
@ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> AbstractPPL.value_and_gradient!!(p, (y=[1.0, 2.0], x=3.0))
ERROR: PreparedCacheSpecError:
│ Cached autodiff call has a type mismatch for `x1`.
│ Expected top-level type: @NamedTuple{x::Float64, y::Vector{Float64}}
│ Got top-level type: @NamedTuple{y::Vector{Float64}, x::Float64}
│ Prepared pullback, gradient, derivative, HVP, and Hessian caches must be reused with the same top-level
│ argument types they were prepared with.
└
Stacktrace:
[1] _throw_prepared_cache_spec_error(kind::Symbol, i::Int64, expected::Type, got::Type)
@ Mooncake ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1071
[2] macro expansion
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1089 [inlined]
[3] _validate_prepared_cache_inputs(specs::Tuple{…}, fx::Tuple{…})
@ Mooncake ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:1078
[4] #value_and_gradient!!#369
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:856 [inlined]
[5] value_and_gradient!!
@ ~/.julia/packages/Mooncake/9Ej1i/src/interface.jl:849 [inlined]
[6] value_and_gradient!!(p::AbstractPPL.Evaluators.Prepared{…}, values::@NamedTuple{…})
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:127
[7] top-level scope
@ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
Same thing happens for missing/renamed fields. For a nested array size mismatch, the evaluator check does fire, so the behavior currently depends on what kind of NamedTuple mismatch happened.
Is the Mooncake cache error the expected API here, or should this behave like prepared(values)? If the AbstractPPL error is preferred, maybe call Evaluators._assert_namedtuple_shape(p.evaluator, values) before Mooncake.value_and_gradient!!.
| "`raw_gradient_target` is only supported for scalar-valued problems." | ||
| ), | ||
| ) | ||
| f, contexts = raw_gradient_target |
Can we add a small validation for the shape of raw_gradient_target before building the Mooncake cache?
Right now malformed inputs can throw MethodError / BoundsError. For the single-context scalar case, it also reaches the cache/construction path before failing because contexts is not a tuple.
julia> using AbstractPPL, ADTypes, Mooncake
julia> raw(x, offset) = -0.5 * (x[1] - offset)^2;
julia> problem = x -> raw(x, 0.1);
julia> AbstractPPL.prepare(AutoMooncake(), problem, [0.3]; raw_gradient_target=raw)
ERROR: MethodError: no method matching iterate(::typeof(raw))
The function `iterate` exists, but no method is defined for this combination of argument types.
Closest candidates are:
iterate(::Compiler.ForwardToBackedgeIterator, ::Int64)
@ Base ~/.julia/juliaup/julia-1.12.6+0.x64.linux.gnu/share/julia/Compiler/src/typeinfer.jl:584
iterate(::Compiler.ForwardToBackedgeIterator)
@ Base ~/.julia/juliaup/julia-1.12.6+0.x64.linux.gnu/share/julia/Compiler/src/typeinfer.jl:584
iterate(::Base.MethodSpecializations)
@ Base runtime_internals.jl:1662
...
Stacktrace:
[1] indexed_iterate(I::Function, i::Int64)
@ Base ./tuple.jl:165
[2] prepare(adtype::AutoMooncake{…}, problem::Function, x::Vector{…}; check_dims::Bool, raw_gradient_target::Function)
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:104
[3] top-level scope
@ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> AbstractPPL.prepare(AutoMooncake(), problem, [0.3]; raw_gradient_target=(raw,))
ERROR: BoundsError: attempt to access Tuple{typeof(raw)} at index [2]
Stacktrace:
[1] indexed_iterate(t::Tuple{typeof(raw)}, i::Int64, state::Int64)
@ Base ./tuple.jl:162
[2] prepare(adtype::AutoMooncake{…}, problem::Function, x::Vector{…}; check_dims::Bool, raw_gradient_target::Tuple{…})
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:104
[3] top-level scope
@ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> AbstractPPL.prepare(AutoMooncake(), problem, [0.3]; raw_gradient_target=(raw, 0.1))
ERROR: MethodError: no method matching (AbstractPPLMooncakeExt.MooncakeCache{:scalar})(::Mooncake.Cache{…}, ::typeof(raw), ::Float64)
The type `AbstractPPLMooncakeExt.MooncakeCache{:scalar}` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(AbstractPPLMooncakeExt.MooncakeCache{A})(::C, ::F, ::CT) where {A, C, F, CT<:Tuple}
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:39
(AbstractPPLMooncakeExt.MooncakeCache{A})(::C) where {A, C}
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:36
Stacktrace:
[1] prepare(adtype::AutoMooncake{…}, problem::Function, x::Vector{…}; check_dims::Bool, raw_gradient_target::Tuple{…})
@ AbstractPPLMooncakeExt ~/Work/vectorly-ai/AbstractPPL.jl/ext/AbstractPPLMooncakeExt.jl:106
[2] top-level scope
@ REPL[5]:1
Some type information was truncated. Use `show(err)` to see complete types.
I think a clear `ArgumentError` would be easier to understand.
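One possible early shape check, per the suggestion above (`_validate_raw_gradient_target` is a hypothetical helper name):

function _validate_raw_gradient_target(rgt)
    rgt isa Tuple{Any,Tuple} || throw(ArgumentError(
        "raw_gradient_target must be a 2-tuple `(f, contexts::Tuple)`; got $(typeof(rgt))",
    ))
    return rgt
end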
…egation and raw_gradient_target as unsafe

Addresses PR #160 review comments:
- Throw a clear `ArgumentError` for non-`DenseVector` inputs instead of letting Mooncake return a shape-incorrect tangent (reverse) or crash inside Mooncake (forward/Jacobian).
- Document that NamedTuple input-shape validation is intentionally delegated to Mooncake's `PreparedCacheSpec` to avoid duplicating checks on every AD call.
- Add a docstring on the vector `prepare` method describing `raw_gradient_target` as an unsafe escape hatch that bypasses evaluator indirection and shape checks.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
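A sketch of the dense-input guard described above (the exact predicate and message in the PR may differ; note a `view` is a `SubArray`, which is not a `DenseVector`):

function _require_dense_vector(x)
    x isa DenseVector{<:Real} || throw(ArgumentError(
        "Mooncake AD requires a dense vector input; got $(typeof(x))",
    ))
    return x
end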
Summary
- `AbstractPPLMooncakeExt`: gradient/jacobian via Mooncake with cache reuse, scalar/vector dispatch, and `NamedTuple` inputs (wrapping `VectorEvaluator` and `NamedTupleEvaluator`). Bumps `AbstractPPL` to `0.15`.
- `AbstractPPLTestExt` conformance suite grows two reusable groups:
  - `Val(:namedtuple)`: value + gradient over `NamedTuple` inputs, plus an `ErrorCase` for structure mismatch.
  - `Val(:cache_reuse)`: three sequential `value_and_{gradient,jacobian}!!` calls against a single `Prepared` to catch backend cache corruption.
- `Evaluators._ad_output_arity` (`:scalar` / `:vector`) and `Evaluators._check_ad_input` factor out the duplicated output-arity check and input validation that appeared across both AD extensions; compile-time `T <: Integer` elision is preserved.
- Empty-input arity tagging (`Val(:scalar)` / `Val(:vector)`) so the `gradient_prep === nothing` arity check still fires for length-0 inputs.