Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ITensorBase"
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
version = "0.10.0"
version = "0.10.1"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
11 changes: 10 additions & 1 deletion src/abstractnamedtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,16 @@ dimnametype(type::Type{<:AbstractNamedTensor}) = Any
# Unwrapping the names (named-array interface).
# TODO: Use `IsNamed` trait?
unnamed(a::AbstractNamedTensor) = throw(MethodError(unnamed, a))
unnamed(a::AbstractNamedTensor, inds) = unnamed(aligneddims(a, inds))
function unnamed(a::AbstractNamedTensor, names)
return _permuteddims_to(unnamed(a), getperm(dimnames(a), names))
end
# Function barrier: `unnamed(a)` is abstractly typed, so dispatching on the concrete array here
# makes `ndims` a compile-time constant. Building the permutation as an `ntuple(…, Val(ndims))`
# (an `NTuple{N,Int}`) rather than `Tuple(perm)` (a length-non-inferrable `Tuple{Vararg{Int}}`)
# lets `permuteddims` build a concretely-typed wrapper, roughly halving the permute cost.
@noinline function _permuteddims_to(array::AbstractArray, perm)
return permuteddims(array, ntuple(i -> perm[i], Val(ndims(array))))
end
unname(a::AbstractNamedTensor, inds) = unnamed(aligndims(a, inds))

"""
Expand Down
207 changes: 67 additions & 140 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using ..ITensorBase: AbstractNamedTensor, ITensorBase, NamedUnitRange, getperm, inds, name,
named, nameddims, unname, unnamed
using Base.Broadcast: Broadcast as BC, Broadcasted, broadcast_shape, broadcasted,
check_broadcast_shape, combine_axes
using ..ITensorBase: AbstractNamedTensor, ITensorBase, dimnames, named, nameddims, unnamed
using Base.Broadcast: Broadcast as BC, Broadcasted, broadcasted
using TensorAlgebra: TensorAlgebra as TA

abstract type AbstractNamedTensorStyle{N} <: BC.AbstractArrayStyle{N} end
Expand All @@ -22,103 +20,16 @@ end
# `AbstractArray`); without this the default `broadcastable` wraps it in a `Ref`.
BC.broadcastable(a::AbstractNamedTensor) = a

function BC.combine_axes(
a1::AbstractNamedTensor, a_rest::AbstractNamedTensor...
)
return broadcast_shape(axes(a1), combine_axes(a_rest...))
end
function BC.combine_axes(a1::AbstractNamedTensor, a2::AbstractNamedTensor)
return broadcast_shape(axes(a1), axes(a2))
end
BC.combine_axes(a::AbstractNamedTensor) = axes(a)

# The named axes are a `Tuple` of `NamedUnitRange`s. Dispatch the
# name-aware shape combination on that tuple form (the elements are not
# `AbstractUnitRange`s, so Base's positional tuple-shape methods do not apply).
function BC.broadcast_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange}},
ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange}},
ax_rest::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}...
)
return broadcast_shape(broadcast_shape(ax1, ax2), ax_rest...)
end

function BC.broadcast_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange}},
ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}
)
return promote_shape(ax1, ax2)
end

# Handle scalar values.
function BC.broadcast_shape(
ax1::Tuple{}, ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}
)
return ax2
end
function BC.broadcast_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}, ax2::Tuple{}
)
return ax1
end

function Base.promote_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange}},
ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}
)
return set_promote_shape(ax1, ax2)
end

function set_promote_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange, N}},
ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange, N}}
) where {N}
perm = getperm(ax2, ax1)
ax2_aligned = map(i -> ax2[i], perm)
ax_promoted = promote_shape(unnamed.(ax1), unnamed.(ax2_aligned))
return named.(ax_promoted, name.(ax1))
broadcasted_unnamed(x::Number, names) = x
function broadcasted_unnamed(a::AbstractNamedTensor, names)
# An operand already aligned to the destination names (the first operand always, and the
# common case for the rest) needs no permutation, avoiding a `getperm` allocation and the
# identity `permuteddims` wrapper. Skipping it makes a small add several times slower.
dimnames(a) == names && return unnamed(a)
return unnamed(a, names)
end

# Handle operations like `randn() + randn(2, 2)[i, j]``.
# TODO: Decide if this should be a general definition for `AbstractNamedTensor`,
# or just for `AbstractNamedTensor`.
function set_promote_shape(
ax1::Tuple{}, ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}
)
return ax2
end

# Handle operations like `randn(2, 2)[i, j] + randn()`.
# TODO: Decide if this should be a general definition for `AbstractNamedTensor`,
# or just for `AbstractNamedTensor`.
function set_promote_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}, ax2::Tuple{}
)
return ax1
end

function BC.check_broadcast_shape(
ax1::Tuple{NamedUnitRange, Vararg{NamedUnitRange}},
ax2::Tuple{NamedUnitRange, Vararg{NamedUnitRange}}
)
return set_check_broadcast_shape(ax1, ax2)
end

function set_check_broadcast_shape(
ax1::Tuple{Any, Vararg{Any, N}},
ax2::Tuple{Any, Vararg{Any, N}}
) where {N}
perm = getperm(ax2, ax1)
ax2_aligned = map(i -> ax2[i], perm)
check_broadcast_shape(unnamed.(ax1), unnamed.(ax2_aligned))
return nothing
end
set_check_broadcast_shape(ax1::Tuple{}, ax2::Tuple{}) = nothing

broadcasted_unnamed(x::Number, inds) = x
broadcasted_unnamed(a::AbstractNamedTensor, inds) = unnamed(a, inds)
function broadcasted_unnamed(bc::Broadcasted, inds)
return broadcasted(bc.f, Base.Fix2(broadcasted_unnamed, inds).(bc.args)...)
function broadcasted_unnamed(bc::Broadcasted, names)
return broadcasted(bc.f, Base.Fix2(broadcasted_unnamed, names).(bc.args)...)
end

# A bare (unnamed) array operand, used as an allocation prototype so a broadcast
Expand All @@ -129,58 +40,74 @@ unnamed_prototype(arg::AbstractNamedTensor, args...) = unnamed(arg)
unnamed_prototype(arg::Broadcasted, args...) = unnamed_prototype(arg.args..., args...)
unnamed_prototype(arg, args...) = unnamed_prototype(args...)

function Base.similar(bc::Broadcasted{<:AbstractNamedTensorStyle}, elt::Type, ax)
inds_a = name.(ax)
bc_unnamed = broadcasted_unnamed(bc, inds_a)
a_unnamed = similar(bc_unnamed, elt)
return nameddims(a_unnamed, inds_a)
# Skip Base's shape-combination step: named broadcasts don't need the `NamedUnitRange` axis
# machinery. Name compatibility is handled by the per-operand alignment in `broadcasted_unnamed`
# (via `getperm`), and unnamed-shape compatibility by TensorAlgebra.
BC.instantiate(bc::Broadcasted{<:AbstractNamedTensorStyle}) = bc

# The destination dimension names of a broadcast are those of its first named operand.
# Sourcing them here (rather than from `axes(bc)`) keeps the named axes off the hot path.
_dimnames(a::AbstractNamedTensor, args...) = dimnames(a)
_dimnames(bc::Broadcasted, args...) = _dimnames(bc.args..., args...)
_dimnames(_, args...) = _dimnames(args...)
dimnames(bc::Broadcasted) = _dimnames(bc.args...)

# The result element type of a linear combination, from the concrete unnamed leaves at runtime.
# `eltype(::LinearBroadcasted)` uses `Base.promote_op`, which runs a live inference call here
# because the leaves wrap a named tensor's (non-inferrable) backing array, so promote the
# concrete `eltype`s instead.
_lineareltype(a::AbstractArray) = eltype(a)
function _lineareltype(s::TA.ScaledBroadcasted)
return promote_type(typeof(TA.coeff(s)), _lineareltype(TA.unscaled(s)))
end
_lineareltype(s::TA.AddBroadcasted) = promote_type(map(_lineareltype, TA.addends(s))...)

inds(bc::Broadcasted) = name.(axes(bc))
function Base.copy(bc::Broadcasted{<:AbstractNamedTensorStyle})
# We could use:
# ```julia
# elt = combine_eltypes(bc.f, bc.args)
# copyto!(similar(bc, elt), bc)
# ```
# but `combine_eltypes` is based on type inference, which might fail.
# Calling broadcasted on the unnamed arrays reuses the code logic in
# Base.Broadcast for handling cases where type inference fails by determining
# the output element type at runtime with widening.
inds_dest = inds(bc)
bc_unnamed = broadcasted_unnamed(bc, inds_dest)
nms = dimnames(bc)
dest_unnamed = _copy_unnamed(broadcasted_unnamed(bc, nms), unnamed_prototype(bc))
return nameddims(dest_unnamed, nms)
end

# Function barrier: `broadcasted_unnamed` and `unnamed_prototype` produce concretely-typed
# values whose *inferred* types are abstract (the named backing array is abstract), so this
# call re-specializes on the concrete runtime types and everything below is type-stable
# (`eltype(lb)` is now inferrable, no runtime `promote_op`). Inlining the body into `copy`
# instead costs one extra allocation per call.
function _copy_unnamed(bc_unnamed, prototype)
lb = TA.tryflattenlinear(bc_unnamed)
if isnothing(lb)
# Not a linear combination: ordinary fused broadcast.
dest_unnamed = copy(bc_unnamed)
else
# Linear: lower to bipermutedimsopadd!. Allocate from an operand so the
# result keeps the backend, using the backend's result axes (not `lb`'s).
dest_axes = unnamed.(Tuple(axes(bc)))
dest_unnamed = similar(unnamed_prototype(bc), eltype(lb), dest_axes)
copyto!(dest_unnamed, lb)
end
return nameddims(dest_unnamed, inds_dest)
isnothing(lb) && return copy(bc_unnamed)
return copyto!(similar(prototype, eltype(lb)), lb)
end

# `Base.Broadcast.materialize!` otherwise reconstructs the broadcast over `axes(dest)` and
# re-runs `instantiate`, forcing the `NamedUnitRange` axis machinery this style's `instantiate`
# no-op exists to skip (`combine_axes`/`set_promote_shape`). Route straight to `copyto!`, which
# aligns by dimname instead.
function BC.materialize!(
dest::AbstractNamedTensor,
bc::Broadcasted{<:AbstractNamedTensorStyle}
)
copyto!(dest, bc)
return dest
end

function Base.copyto!(
dest::AbstractNamedTensor,
bc::Broadcasted{<:AbstractNamedTensorStyle}
)
dest_unnamed = unnamed(dest)
inds_dest = axes(dest)
bc_unnamed = broadcasted_unnamed(bc, inds_dest)
lb = TA.tryflattenlinear(bc_unnamed)
if isnothing(lb)
# Not a linear combination: ordinary fused broadcast.
copyto!(dest_unnamed, bc_unnamed)
else
# Linear: lower to bipermutedimsopadd! into the existing dest.
copyto!(dest_unnamed, lb)
end
_copyto_unnamed!(unnamed(dest), broadcasted_unnamed(bc, dimnames(dest)))
return dest
end

# Function barrier mirroring `_copy_unnamed`: `unnamed(dest)` and `broadcasted_unnamed`
# have abstract inferred types (the named backing array is abstract), so this call
# re-specializes on the concrete runtime types and the flatten/lower below is type-stable.
function _copyto_unnamed!(dest_unnamed, bc_unnamed)
lb = TA.tryflattenlinear(bc_unnamed)
isnothing(lb) && return copyto!(dest_unnamed, bc_unnamed)
return copyto!(dest_unnamed, lb)
end

# Operator-preserving broadcasting.
#
# An `NamedTensorOperator` broadcasts as itself (it does not peel to its `state`), so
Expand All @@ -190,7 +117,7 @@ end
# - operator ⊗ scalar → operator (`2 .* op` stays an operator),
# - operator ⊗ non-operator tensor → error.
# The `BroadcastStyle(::Type{<:NamedTensorOperator})` mapping and the operator-specific
# `copy` / `similar` (which unwrap, delegate to `NamedTensorStyle`, then rewrap) live in
# `copy` (which unwraps, delegates to `NamedTensorStyle`, then rewraps) live in
# `itensoroperator.jl`, where `NamedTensorOperator` is defined. `*` (contraction) is
# unchanged and still decays to `state`.

Expand Down
6 changes: 0 additions & 6 deletions src/namedtensoroperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,6 @@ function Base.copy(bc::Broadcasted{<:NamedTensorOperatorStyle})
return operator(result, cod, dom)
end

function Base.similar(bc::Broadcasted{<:NamedTensorOperatorStyle}, elt::Type, ax)
cod, dom = broadcast_operator_codomain_domain(bc)
result = similar(statebroadcasted(bc), elt, ax)
return operator(result, cod, dom)
end

for f in MATRIX_FUNCTIONS
@eval begin
function Base.$f(a::NamedTensorOperator)
Expand Down
Loading