Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions src/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,46 @@ function finite_difference_gradient(
dir = true) where {T1, T2, T3, T4, fdtype, returntype, inplace}
if typeof(x) <: AbstractArray
df = zero(returntype) .* x
finite_difference_gradient!(
df, f, x, cache, relstep = relstep, absstep = absstep, dir = dir)
df
else
df = zero(cache.c1)
# Scalar x: compute out-of-place to support immutable output types
# (e.g. ArrayPartition{SVector} from SecondOrderODEProblem).
_scalar_gradient_oop(f, x, cache, fdtype, returntype, inplace;
relstep = relstep, absstep = absstep, dir = dir)
end
end

# Out-of-place scalar→vector gradient that never mutates the result,
# so it works even when f returns immutable arrays (SVector, etc.).
function _scalar_gradient_oop(
f, x::Number, cache, fdtype, returntype, inplace;
relstep, absstep, dir)
fx, c1, c2 = cache.fx, cache.c1, cache.c2

if fdtype == Val(:forward)
epsilon = compute_epsilon(Val(:forward), x, relstep, absstep, dir)
_c1 = inplace == Val(true) ? (f(c1, x + epsilon); c1) : f(x + epsilon)
if typeof(fx) != Nothing
@. (_c1 - fx) / epsilon
else
_c2 = inplace == Val(true) ? (f(c2, x); c2) : f(x)
@. (_c1 - _c2) / epsilon
end
elseif fdtype == Val(:central)
epsilon = compute_epsilon(Val(:central), x, relstep, absstep, dir)
_c1 = inplace == Val(true) ? (f(c1, x + epsilon); c1) : f(x + epsilon)
_c2 = inplace == Val(true) ? (f(c2, x - epsilon); c2) : f(x - epsilon)
@. (_c1 - _c2) / (2 * epsilon)
elseif fdtype == Val(:complex) && returntype <: Real
epsilon_complex = eps(real(eltype(x)))
_c1 = inplace == Val(true) ?
(f(c1, x + im * epsilon_complex); c1) : f(x + im * epsilon_complex)
@. imag(_c1) / epsilon_complex
else
fdtype_error(returntype)
end
finite_difference_gradient!(
df, f, x, cache, relstep = relstep, absstep = absstep, dir = dir)
df
end

# vector of derivatives of a vector->scalar map by each component of a vector x
Expand Down
34 changes: 34 additions & 0 deletions test/finitedifftests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,40 @@ complex_cache = FiniteDiff.GradientCache(df, x, Val{:complex})
@test err_func(FiniteDiff.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
end

@time @testset "Gradient of f:scalar->vector with immutable output" begin
# Regression test: finite_difference_gradient with scalar x must work
# when f returns immutable arrays, since the out-of-place API should
# never need to mutate the result. This failed previously because the
# cached version called finite_difference_gradient! which used @. df = ...
# to write into an immutable buffer.
#
# We use a wrapper around a regular Vector that blocks setindex! to
# simulate immutable array types (like StaticArrays.SVector or
# ArrayPartition containing SVectors).
struct ReadOnlyVec{T} <: AbstractVector{T}
data::Vector{T}
end
Base.size(v::ReadOnlyVec) = size(v.data)
Base.getindex(v::ReadOnlyVec, i::Int) = v.data[i]
Base.setindex!(::ReadOnlyVec, _, ::Int) = error("ReadOnlyVec does not support setindex!")
Base.similar(v::ReadOnlyVec) = ReadOnlyVec(zeros(eltype(v), length(v)))
Base.zero(v::ReadOnlyVec) = ReadOnlyVec(zeros(eltype(v), length(v)))
# Out-of-place broadcast returns a plain Vector (like SVector .+ SVector returns SVector)
Base.BroadcastStyle(::Type{<:ReadOnlyVec}) = Broadcast.DefaultArrayStyle{1}()

g(t) = ReadOnlyVec([sin(t), cos(t)])
t0 = 1.0
g_ref = [cos(t0), -sin(t0)]

# Out-of-place cached version (the bug path)
df_template = similar(g(t0))
for fdtype in (Val(:forward), Val(:central))
cache = FiniteDiff.GradientCache(df_template, t0, fdtype, Float64, Val(false))
result = FiniteDiff.finite_difference_gradient(g, t0, cache)
@test err_func(collect(result), g_ref) < 1e-4
end
end

f(df, x) = (df[1] = sin(x); df[2] = cos(x); df)
x = (2π * rand()) * (1 + im)
fx = fill(zero(typeof(x)), 2)
Expand Down
Loading