diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index e4ef4c1f..40453b19 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -44,6 +44,8 @@ end # using linear indexing (which now returns scalar elements for VectorOfArray). if y isa AbstractVectorOfArray (y.u,) + elseif y isa NamedTuple && haskey(y, :u) + (y.u,) else ( [ @@ -68,6 +70,8 @@ end end if y isa AbstractVectorOfArray (y.u, nothing) + elseif y isa NamedTuple && haskey(y, :u) + (y.u, nothing) else ( [ @@ -113,11 +117,15 @@ end @adjoint function Base.Array(VA::AbstractVectorOfArray) adj = let VA = VA - function Array_adjoint(y) - # Return a VectorOfArray so it flows correctly back through VectorOfArray constructor - VA = recursivecopy(VA) - copyto!(VA, y) - return (VA,) + function Array_adjoint(y::AbstractArray{T, N}) where {T, N} + arrarr = [ + [ + y[ntuple(_ -> Colon(), Val(N - 2))..., j, i] + for j in 1:size(y)[end - 1] + ] + for i in 1:size(y)[end] + ] + return ((u = arrarr,),) end end Array(VA), adj diff --git a/test/adjoints.jl b/test/adjoints.jl index fec7f548..00b2d178 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -90,3 +90,45 @@ voa_gs, = Zygote.gradient(voa) do x sum(sum.(x.u)) end @test voa_gs isa RecursiveArrayTools.VectorOfArray + +@testset "Base.Array(::AbstractVectorOfArray) cotangent shape" begin + let voa = VectorOfArray([Float64.(1:3), Float64.(4:6), Float64.(7:9)]) + y = Array(voa) + @test size(y) == (3, 3) + _, back = Zygote.pullback(Base.Array, voa) + cot, = back(ones(Float64, size(y))) + @test cot isa NamedTuple + @test haskey(cot, :u) + @test length(cot.u) == length(voa.u) + for i in eachindex(voa.u) + @test cot.u[i] == ones(Float64, length(voa.u[i])) + end + end + + let ntraj = 4, ntime = 5, nstate = 2 + voa = VectorOfArray([ + VectorOfArray([Float64.((j - 1) * nstate .+ (1:nstate)) .+ (i - 0.5) + for j in 1:ntime]) + for i in 1:ntraj + ]) + y = Array(voa) + @test size(y) == (nstate, ntime, ntraj) + _, back = Zygote.pullback(Base.Array, voa) + cot, = back(reshape(collect(Float64, 1:length(y)), size(y))) + @test cot isa NamedTuple + @test length(cot.u) == ntraj + for i in 1:ntraj + @test length(cot.u[i]) == ntime + @test all(length(v) == nstate for v in cot.u[i]) + end + end +end + +@testset "Array(::VectorOfArray) gradient matches ForwardDiff" begin + function row_loss(x) + voa = VectorOfArray([x .* i for i in 1:5]) + sum(abs2, 1.0 .- Array(voa)) + end + x = collect(Float64, 1:5) + @test Zygote.gradient(row_loss, x)[1] == ForwardDiff.gradient(row_loss, x) +end