From 966d2acaa4439171e8d3e33bb3779f525d2d5d88 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Fri, 8 May 2026 09:51:13 -0400 Subject: [PATCH] fix(zygote ext): return struct-tangent cotangent from `Array(::AbstractVectorOfArray)` The previous `@adjoint Base.Array(VA::AbstractVectorOfArray)` reshaped the flat-array cotangent `y` back into the AbstractVectorOfArray's `.u` layout by `copyto!`'ing into a `recursivecopy(VA)` and returning that as the cotangent. The result is the same AbstractVectorOfArray subtype as `VA`. Under AbstractVectorOfArray's flat scalar `getindex`, downstream Zygote pullbacks that walk the cotangent by row index (e.g. SciMLBase's EnsembleSolution constructor pullback iterating per-trajectory tangents, or DiffEqGPU's `batch_solve` comprehension indexing into `solus[i]`) receive the i-th SCALAR of the underlying linear layout instead of the i-th row. The chain then either drops gradient information silently (scalar where vector expected) or fails with `DimensionMismatch` / `Need an adjoint for constructor` deeper in the trace. Return a NamedTuple `(u = arrarr,)` struct tangent instead, matching the AbstractVectorOfArray field layout. Downstream constructor pullbacks (`VectorOfArray(u)` and `DiffEqArray(u, t)`) now also accept a NamedTuple cotangent with `:u` and forward `y.u` directly. Verified end-to-end on the EnsembleGPUArray reverse-mode chain (Zygote + DiffEqGPU + JLArrays); previously hit `DimensionMismatch` at the EnsembleSolution constructor pullback, now flows through cleanly. Tests added in `test/adjoints.jl`: - Direct shape assertions on the cotangent for both 2D and 3D AbstractVectorOfArray layouts. - End-to-end Zygote-vs-ForwardDiff parity on a loss that walks the row layout. Co-Authored-By: Chris Rackauckas --- ext/RecursiveArrayToolsZygoteExt.jl | 18 +++++++++---- test/adjoints.jl | 42 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index e4ef4c1f..40453b19 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -44,6 +44,8 @@ end # using linear indexing (which now returns scalar elements for VectorOfArray). if y isa AbstractVectorOfArray (y.u,) + elseif y isa NamedTuple && haskey(y, :u) + (y.u,) else ( [ @@ -68,6 +70,8 @@ end end if y isa AbstractVectorOfArray (y.u, nothing) + elseif y isa NamedTuple && haskey(y, :u) + (y.u, nothing) else ( [ @@ -113,11 +117,15 @@ end @adjoint function Base.Array(VA::AbstractVectorOfArray) adj = let VA = VA - function Array_adjoint(y) - # Return a VectorOfArray so it flows correctly back through VectorOfArray constructor - VA = recursivecopy(VA) - copyto!(VA, y) - return (VA,) + function Array_adjoint(y::AbstractArray{T, N}) where {T, N} + arrarr = [ + [ + y[ntuple(_ -> Colon(), Val(N - 2))..., j, i] + for j in 1:size(y)[end - 1] + ] + for i in 1:size(y)[end] + ] + return ((u = arrarr,),) end end Array(VA), adj diff --git a/test/adjoints.jl b/test/adjoints.jl index fec7f548..00b2d178 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -90,3 +90,45 @@ voa_gs, = Zygote.gradient(voa) do x sum(sum.(x.u)) end @test voa_gs isa RecursiveArrayTools.VectorOfArray + +@testset "Base.Array(::AbstractVectorOfArray) cotangent shape" begin + let voa = VectorOfArray([Float64.(1:3), Float64.(4:6), Float64.(7:9)]) + y = Array(voa) + @test size(y) == (3, 3) + _, back = Zygote.pullback(Base.Array, voa) + cot, = back(ones(Float64, size(y))) + @test cot isa NamedTuple + @test haskey(cot, :u) + @test length(cot.u) == length(voa.u) + for i in eachindex(voa.u) + @test cot.u[i] == ones(Float64, length(voa.u[i])) + end + end + + let ntraj = 4, ntime = 5, nstate = 2 + voa = VectorOfArray([ + VectorOfArray([Float64.((j - 1) * nstate .+ (1:nstate)) .+ (i - 0.5) + for j in 1:ntime]) + for i in 1:ntraj + ]) + y = Array(voa) + @test size(y) == (nstate, ntime, ntraj) + _, back = Zygote.pullback(Base.Array, voa) + cot, = back(reshape(collect(Float64, 1:length(y)), size(y))) + @test cot isa NamedTuple + @test length(cot.u) == ntraj + for i in 1:ntraj + @test length(cot.u[i]) == ntime + @test all(length(v) == nstate for v in cot.u[i]) + end + end +end + +@testset "Array(::VectorOfArray) gradient matches ForwardDiff" begin + function row_loss(x) + voa = VectorOfArray([x .* i for i in 1:5]) + sum(abs2, 1.0 .- Array(voa)) + end + x = collect(Float64, 1:5) + @test Zygote.gradient(row_loss, x)[1] == ForwardDiff.gradient(row_loss, x) +end