diff --git a/src/TensorStoreWrapper.jl b/src/TensorStoreWrapper.jl index 160fc34..da888a4 100644 --- a/src/TensorStoreWrapper.jl +++ b/src/TensorStoreWrapper.jl @@ -150,6 +150,12 @@ function Base.setindex!(w::TensorStoreWrapper, v, indices...; kwargs...) end Base.size(w::TensorStoreWrapper) = pyconvert(Tuple, parent(w).shape) +function Base.size(w::TensorStoreWrapper, d::Integer) + d < 1 && throw(ArgumentError("dimension must be ≥ 1")) + d > ndims(w) && return 1 + di = Base.to_index(d) + return pyconvert(Int, parent(w).shape[di-1]) +end Base.ndims(w::TensorStoreWrapper) = pyconvert(Int, parent(w).rank) const TS_TYPE_MAP = Dict( @@ -228,6 +234,12 @@ end # IndexDomainWrapper methods Base.size(w::IndexDomainWrapper) = pyconvert(Tuple, parent(w).shape) +function Base.size(w::IndexDomainWrapper, d::Integer) + d < 1 && throw(ArgumentError("dimension must be ≥ 1")) + d > ndims(w) && return 1 + di = Base.to_index(d) + return pyconvert(Int, parent(w).shape[di-1]) +end Base.ndims(w::IndexDomainWrapper) = pyconvert(Int, parent(w).rank) function Base.axes(w::IndexDomainWrapper) rank = ndims(w) diff --git a/test/runtests.jl b/test/runtests.jl index 4068501..21727b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,10 @@ using PythonCall @test eltype(w) == Int32 @test ndims(w) == 2 @test size(w) == (10, 20) + @test size(w, 1) == 10 + @test size(w, 2) == 20 + @test size(w, 3) == 1 + @test_throws ArgumentError size(w, 0) @test axes(w) == (1:10, 1:20) end @@ -44,6 +48,10 @@ using PythonCall # Labeled indexing sub_w = w[x=1:5, y=11:15] @test size(sub_w) == (5, 5) + @test size(domain, 1) == 10 + @test size(domain, 2) == 20 + @test size(domain, 3) == 1 + @test_throws ArgumentError size(domain, 0) @test axes(sub_w) == (1:5, 11:15) # translate_by