diff --git a/Project.toml b/Project.toml index dc5bf29..617c54b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" -version = "0.10.3" +version = "0.10.4" authors = ["ITensor developers and contributors"] [workspace] @@ -35,5 +35,5 @@ LinearAlgebra = "1.10" MapBroadcast = "0.1.5" Random = "1.10" SparseArrays = "1.10" -TensorAlgebra = "0.11, 0.12, 0.13" +TensorAlgebra = "0.14" julia = "1.10" diff --git a/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl b/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl index afe9bc0..7808033 100644 --- a/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl +++ b/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl @@ -2,7 +2,8 @@ module SparseArraysBaseTensorAlgebraExt using SparseArrays: SparseMatrixCSC using SparseArraysBase: AnyAbstractSparseArray, AnyAbstractSparseMatrix, SparseArrayDOK -using TensorAlgebra: TensorAlgebra, FusionStyle, ReshapeFusion, matricize, unmatricize +using TensorAlgebra: + TensorAlgebra, FusionStyle, ReshapeFusion, bipermutedimsopadd!, matricize, unmatricize struct SparseArrayFusion <: FusionStyle end TensorAlgebra.FusionStyle(::Type{<:AnyAbstractSparseArray}) = SparseArrayFusion() @@ -24,4 +25,35 @@ function TensorAlgebra.unmatricize( return convert(SparseArrayDOK, a) end +# A sparse array can't be wrapped in a `StridedView`, so the generic +# `bipermutedimsopadd!` doesn't apply. Accumulate over a lazily permuted source via +# broadcasting, which dispatches to the sparse broadcast path. `_opadd!` mirrors the +# accumulation in TensorAlgebra's generic method. +function TensorAlgebra.bipermutedimsopadd!( + dest::AnyAbstractSparseArray, op, src::AbstractArray, + perm_codomain, perm_domain, + α::Number, β::Number + ) + perm = (perm_codomain..., perm_domain...) + _opadd!(dest, op, PermutedDimsArray(src, perm), α, β) + return dest +end + +function _opadd!(dest::AbstractArray, op, src::AbstractArray, α, β) + if op === identity + if iszero(β) + dest .= α .* src + else + dest .= β .* dest .+ α .* src + end + else + if iszero(β) + dest .= α .* op.(src) + else + dest .= β .* dest .+ α .* op.(src) + end + end + return dest +end + end diff --git a/test/Project.toml b/test/Project.toml index 58332f4..252666c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -34,5 +34,5 @@ SparseArrays = "1.10" SparseArraysBase = "0.10" StableRNGs = "1.0.2" Suppressor = "0.2.8" -TensorAlgebra = "0.11, 0.12, 0.13" +TensorAlgebra = "0.14" Test = "<0.0.1, 1"