diff --git a/src/TensorStoreWrapper.jl b/src/TensorStoreWrapper.jl index 160fc34..4c073d1 100644 --- a/src/TensorStoreWrapper.jl +++ b/src/TensorStoreWrapper.jl @@ -187,6 +187,9 @@ function Base.axes(w::TensorStoreWrapper) max_indices = pyconvert(Vector{Int}, domain.exclusive_max) return Tuple((min_indices[i]+1):max_indices[i] for i in 1:rank) end +Base.axes(w::TensorStoreWrapper, d::Integer) = axes(w)[d] +Base.firstindex(w::TensorStoreWrapper, d::Integer) = first(axes(w, d)) +Base.lastindex(w::TensorStoreWrapper, d::Integer) = last(axes(w, d)) function Base.show(io::IO, w::TensorStoreWrapper) print(io, "TensorStore(", eltype(w), ", rank=", ndims(w), ", shape=", size(w), ")") @@ -235,6 +238,9 @@ function Base.axes(w::IndexDomainWrapper) max_indices = pyconvert(Vector{Int}, parent(w).exclusive_max) return Tuple((min_indices[i]+1):max_indices[i] for i in 1:rank) end +Base.axes(w::IndexDomainWrapper, d::Integer) = axes(w)[d] +Base.firstindex(w::IndexDomainWrapper, d::Integer) = first(axes(w, d)) +Base.lastindex(w::IndexDomainWrapper, d::Integer) = last(axes(w, d)) """ labels(w::IndexDomainWrapper) -> Vector{String} diff --git a/test/runtests.jl b/test/runtests.jl index 4068501..c0b26be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,10 @@ using PythonCall @test ndims(w) == 2 @test size(w) == (10, 20) @test axes(w) == (1:10, 1:20) + @test firstindex(w, 1) == 1 + @test lastindex(w, 1) == 10 + @test lastindex(w, 2) == 20 + @test size(w[begin:2:end, begin:2:end]) == (5,10) end @testset "Write & Read Operations" begin @@ -45,6 +49,10 @@ using PythonCall sub_w = w[x=1:5, y=11:15] @test size(sub_w) == (5, 5) @test axes(sub_w) == (1:5, 11:15) + @test firstindex(sub_w, 1) == 1 + @test lastindex(sub_w, 1) == 5 + @test lastindex(sub_w, 2) == 15 + @test size(sub_w[begin:2:end, begin:2:end]) == (3,3) # translate_by tw = PyTensorStore.translate_by(w, 10, 20)