diff --git a/Project.toml b/Project.toml index 6646343..c8cb451 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,20 @@ uuid = "b58c8408-13c4-4787-8733-7038ae624acf" version = "0.2.1" authors = ["Anton Oresten and contributors"] +[workspace] +projects = ["test", "docs"] + [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Republic = "27243419-9dde-4721-b67c-fd63626fea7f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[weakdeps] +cuTile = "0dea8319-8c4a-4662-a73d-20234d115b9a" + +[extensions] +cuTileExt = "cuTile" + [compat] Adapt = "4.6.1" Republic = "2.2.0" diff --git a/ext/cuTileExt.jl b/ext/cuTileExt.jl new file mode 100644 index 0000000..d3acc67 --- /dev/null +++ b/ext/cuTileExt.jl @@ -0,0 +1,41 @@ +module cuTileExt + +using BitPacking: NarrowArray, bitwidth + +import Adapt +import cuTile as ct + +struct ReinterpretTileArray{T,N,A<:ct.TileArray{UInt8,N}} <: ct.AbstractTileArray{T,N} + parent::A +end + +Base.parent(arr::ReinterpretTileArray) = arr.parent +Base.eltype(::ReinterpretTileArray{T}) where T = T +Base.ndims(::ReinterpretTileArray{T,N}) where {T,N} = N + +function Base.size(arr::ReinterpretTileArray, i::Integer) + ratio = 8 ÷ bitwidth(eltype(arr)) + return i == 1 ? size(parent(arr), i) * ratio : size(parent(arr), i) +end +Base.size(arr::ReinterpretTileArray) = ntuple(i -> size(arr, i), Val(ndims(arr))) + +function Adapt.adapt_structure(to::ct.KernelAdaptor, arr::NarrowArray) + parent = Adapt.adapt(to, reinterpret(UInt8, arr)) + return ReinterpretTileArray{eltype(arr),ndims(parent),typeof(parent)}(parent) +end + +function ct.store(arr::ReinterpretTileArray, index, tile; kws...) + return ct.store(parent(arr), index, reinterpret(UInt8, tile); kws...) +end + +function ct.load(arr::ReinterpretTileArray, index, shape; kws...) + ratio = 8 ÷ bitwidth(eltype(arr)) + shape′ = ntuple(Val(ndims(arr))) do i + i == 1 ? shape[i] ÷ ratio : shape[i] + end + byte_tile = ct.load(parent(arr), index, shape′; kws...) + tile = reinterpret(eltype(arr), byte_tile) + return tile +end + +end diff --git a/test/Project.toml b/test/Project.toml index 0c36332..9597d47 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,2 +1,3 @@ [deps] +BitPacking = "b58c8408-13c4-4787-8733-7038ae624acf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"