diff --git a/docs/src/index.md b/docs/src/index.md index f561474..51d40ea 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -18,6 +18,7 @@ BitPacking.bitwidth ```@docs NArray NarrowArray +Narrow ``` ## Narrow tuples diff --git a/src/BitPacking.jl b/src/BitPacking.jl index ea8caa6..c856faa 100644 --- a/src/BitPacking.jl +++ b/src/BitPacking.jl @@ -1,11 +1,12 @@ module BitPacking -import Adapt using Republic +import Adapt @public bitwidth export NArray, NVector, NMatrix export NarrowArray, NarrowVector, NarrowMatrix +export Narrow export NarrowTuple, @NarrowTuple @public Pad, ZeroPad, OnePad diff --git a/src/NarrowArray.jl b/src/NarrowArray.jl index 4f49801..a7f8c0b 100644 --- a/src/NarrowArray.jl +++ b/src/NarrowArray.jl @@ -137,8 +137,8 @@ function Base.reinterpret(::Type{T}, arr::NarrowArray{S}) where {T,S} end end -function Base.Broadcast.broadcasted(::Type{T}, arr::NarrowArray) where T - isbitstype(T) || return Base.Broadcast.Broadcasted(T, (arr,)) +function Broadcast.broadcasted(::Type{T}, arr::NarrowArray) where T + isbitstype(T) || return Broadcast.Broadcasted(T, (arr,)) L = length(eltype(parent(arr))) chunks = SVector{L,T}.(parent(arr)) @@ -147,11 +147,72 @@ end Base.copy(arr::NarrowArray) = reinterpret(eltype(arr), map(SArray, parent(arr))) +""" + Narrow{T} + +Representation tag for the packed form of logical element type `T`. `Narrow` has +no instances; it exists purely for dispatch: passing `Narrow{T}` selects the +packed [`NarrowArray{T}`](@ref) form where plain `T` selects the unpacked form. + +| operation | with `T` | with `Narrow{T}` | +|:--------------|:--------------------------------|:--------------------------------| +| `reinterpret` | `reinterpret(T, ::NarrowArray)` | `reinterpret(Narrow{T}, data)` | +| broadcast | `T.(::NarrowArray)` | `Narrow{T}.(array)` | + +For broadcast these are value conversions: `T.(narr)` unpacks to dense `T` +values and `Narrow{T}.(array)` packs values into a `NarrowArray{T}`. For +`reinterpret` they are instead bit-preserving views of the same buffer in the two +layouts: `reinterpret(T, narr)` views the packed bits as `T`, while +`reinterpret(Narrow{T}, data)` views an existing array of packed chunks as a +`NarrowArray{T}` without copying. + +`Narrow{T}.(array)` makes the narrowing explicit where `NarrowArray{T}(array)` +hides it; the equivalent in-place form is `dest .= expr` for a preallocated +`NarrowArray{T}` destination. All forms use the default chunk length +`pack_count(T)`, so the leading dimension must be a whole number of chunks. +""" +abstract type Narrow{T} end + +# Pack `dense` (logical values) into `chunks` by reinterpreting each run of `L` +# values along the first dimension as one `NVector{T,L}`. The fused `.=` writes +# straight into the existing `chunks`, so no packed temporary is allocated. +function _pack_into!(chunks, ::Type{T}, ::Val{L}, dense::AbstractArray{S}) where {T,L,S} + if S === T + chunks .= NVector{T,L}.(reinterpret(NTuple{L,T}, dense)) + else + chunks .= NVector{T,L}.(SVector{L,S}.(reinterpret(NTuple{L,S}, dense))) + end + return chunks +end + +# `dest .= expr` materializes the (fused) broadcast once, then packs it directly +# into `dest`'s existing parent at `dest`'s own chunk length `L`. +function Base.copyto!(dest::NarrowArray{T,N,L}, bc::Broadcast.Broadcasted{Nothing}) where {T,N,L} + axes(dest) == axes(bc) || + throw(DimensionMismatch("destination axes $(axes(dest)) do not match broadcast axes $(axes(bc))")) + dense = Broadcast.materialize(Broadcast.broadcasted(bc.f, bc.args...)) + _pack_into!(parent(dest), T, Val(L), dense) + return dest +end + +Base.similar(arr::NarrowArray) = NarrowArray(similar(parent(arr))) + +# `Narrow{T}.(x)` packs the (fused) broadcast `x` into a NarrowArray{T}. Routing +# through the constructor reuses its vectorized, backend-generic packing, so the +# result follows the backend of `x` rather than allocating a host `Array`. +_narrow_broadcast(::Type{T}, x) where T = NarrowArray{T}(Broadcast.materialize(x)) + +Broadcast.broadcasted(::Type{Narrow{T}}, x) where T = _narrow_broadcast(T, x) +Broadcast.broadcasted(::Type{Narrow{T}}, x::NarrowArray) where T = _narrow_broadcast(T, x) + +Base.reinterpret(::Type{Narrow{T}}, arr::AbstractArray) where T = + NarrowArray(reinterpret(narrow_chunk_type(T), arr)) + function Base.print_array(io::IO, arr::NarrowArray) host = Adapt.adapt(Array, arr) if host isa AbstractVecOrMat - return invoke(Base.print_array, Tuple{IO, AbstractVecOrMat}, io, host) + return @invoke Base.print_array(io::IO, host::AbstractVecOrMat) else - return invoke(Base.print_array, Tuple{IO, AbstractArray}, io, host) + return @invoke Base.print_array(io::IO, host::AbstractArray) end end diff --git a/test/runtests.jl b/test/runtests.jl index 02f666f..079b2a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -321,4 +321,64 @@ bits(x) = x @test_throws ArgumentError reinterpret(UInt16, narrow) end + @testset "Narrow" begin + values = Bool[1, 0, 1, 0, 1, 0, 1, 0] + + # Narrow{T}.(arr) packs into a NarrowArray{T} + packed = Narrow{Bool}.(values) + @test packed isa NarrowVector{Bool} + @test copy(packed) == values + @test packed == NarrowArray{Bool}(values) + + # the inner expression fuses, narrowing happens at the boundary + fused = Narrow{Bool}.(.!values) + @test fused isa NarrowVector{Bool} + @test copy(fused) == .!values + + # reinterpret(Narrow{T}, bytes) views packed bytes without copying + bytes = UInt8[0x55] + viewed = reinterpret(Narrow{Bool}, bytes) + @test viewed isa NarrowVector{Bool} + @test copy(viewed) == values + bytes[1] = 0x00 + @test copy(viewed) == falses(8) + + # destination broadcasting packs into preallocated narrow storage + dest = similar(NarrowArray{Bool}(values)) + @test dest isa NarrowVector{Bool} + @test size(dest) == size(values) + dest .= values + @test copy(dest) == values + dest .= .!values + @test copy(dest) == .!values + + # float4 round trip via both packing forms + f4_values = Float4_E2M1FN[float4(0x01), float4(0x02), float4(0x03), float4(0x04)] + f4_packed = Narrow{Float4_E2M1FN}.(f4_values) + @test f4_packed isa NarrowVector{Float4_E2M1FN} + @test collect(reinterpret(UInt8, f4_packed)) == UInt8[0x21, 0x43] + @test bits.(copy(reinterpret(Narrow{Float4_E2M1FN}, UInt8[0x21, 0x43]))) == bits.(f4_values) + + # cross-type packing converts before packing, matching the constructor + @test Narrow{Float4_E2M1FN}.(UInt8[0x01, 0x02, 0x03, 0x04]) == f4_packed + + # matrix destination chunks along the first dimension + src = repeat(values, 1, 2) + mat = similar(NarrowArray{Bool}(src)) + mat .= src + @test copy(mat) == src + + # cross-type destination broadcast converts before packing + f4_dest = similar(f4_packed) + f4_dest .= UInt8[0x01, 0x02, 0x03, 0x04] + @test collect(reinterpret(UInt8, f4_dest)) == UInt8[0x21, 0x43] + + # Narrow{T}.(narr) dispatches on a NarrowArray source + @test Narrow{Bool}.(packed) == packed + + # print_array handles arrays beyond vectors and matrices + arr3 = NarrowArray{Bool}(reshape(repeat(values, 4), 8, 2, 2)) + @test contains(sprint(show, MIME("text/plain"), arr3), "NarrowArray") + end + end