From 8bce3d61c55817a0974f9fe004e6f81d45dd7ce4 Mon Sep 17 00:00:00 2001 From: Tamme Claus Date: Sun, 7 Dec 2025 01:26:39 +0100 Subject: [PATCH 1/2] create contiguous views also for views with Base.Slice --- src/host/base.jl | 8 +++++--- test/testsuite/base.jl | 8 ++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/host/base.jl b/src/host/base.jl index eb381ebba..dfb6139eb 100644 --- a/src/host/base.jl +++ b/src/host/base.jl @@ -278,9 +278,11 @@ struct NonContiguous end GPUIndexStyle() = Contiguous() GPUIndexStyle(I...) = NonContiguous() GPUIndexStyle(::Union{Base.ScalarIndex, CartesianIndex}...) = Contiguous() -GPUIndexStyle(i1::Colon, ::Union{Base.ScalarIndex, CartesianIndex}...) = Contiguous() -GPUIndexStyle(i1::AbstractUnitRange, ::Union{Base.ScalarIndex, CartesianIndex}...) = Contiguous() -GPUIndexStyle(i1::Colon, I...) = GPUIndexStyle(I...) +GPUIndexStyle(::Colon, ::Union{Base.ScalarIndex, CartesianIndex}...) = Contiguous() +GPUIndexStyle(::Base.Slice, ::Union{Base.ScalarIndex, CartesianIndex}...) = Contiguous() +GPUIndexStyle(::AbstractUnitRange, ::Union{Base.ScalarIndex, CartesianIndex}...) = Contiguous() +GPUIndexStyle(::Colon, I...) = GPUIndexStyle(I...) +GPUIndexStyle(::Base.Slice, I...) = GPUIndexStyle(I...) viewlength() = () @inline viewlength(::Real, I...) = viewlength(I...) # skip scalar diff --git a/test/testsuite/base.jl b/test/testsuite/base.jl index 6bcfd5b4d..5a19bf13d 100644 --- a/test/testsuite/base.jl +++ b/test/testsuite/base.jl @@ -333,11 +333,19 @@ end end @test compare(x->view(x, :, 1:4, 3), AT, rand(Float32, 5, 4, 3)) + @test compare(x->view(x, Base.Slice(Base.OneTo(5)), 1:4, 3), AT, rand(Float32, 5, 4, 3)) let x = AT(rand(Float32, 5, 4, 3)) @test_throws BoundsError view(x, :, :, 1:10) end + @test compare(x -> selectdim(x, 3, 1), AT, rand(Float32, 2, 2, 2)) + let x = AT(rand(Float32, 5, 4, 3)) + @test typeof(view(x, :, :, 1:2)) == typeof(view(x, :, :, Base.Slice(Base.OneTo(2)))) + @test typeof(view(x, :, :, 1)) == typeof(view(x, :, Base.Slice(Base.OneTo(4)), 1)) + @test typeof(selectdim(x, 3, 1)) == typeof(view(x, :, :, 1)) + end + # bug in parentindices conversion let x = AT{Int}(undef, 1, 1) x[1,:] .= 42 From 8b5a985ecfa079cd95cdf012959ac6a21738d833 Mon Sep 17 00:00:00 2001 From: Tamme Claus Date: Sun, 7 Dec 2025 01:47:12 +0100 Subject: [PATCH 2/2] fix tests (for array) --- test/testsuite/base.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/testsuite/base.jl b/test/testsuite/base.jl index 5a19bf13d..8a5ca855e 100644 --- a/test/testsuite/base.jl +++ b/test/testsuite/base.jl @@ -339,11 +339,12 @@ end @test_throws BoundsError view(x, :, :, 1:10) end - @test compare(x -> selectdim(x, 3, 1), AT, rand(Float32, 2, 2, 2)) - let x = AT(rand(Float32, 5, 4, 3)) - @test typeof(view(x, :, :, 1:2)) == typeof(view(x, :, :, Base.Slice(Base.OneTo(2)))) - @test typeof(view(x, :, :, 1)) == typeof(view(x, :, Base.Slice(Base.OneTo(4)), 1)) - @test typeof(selectdim(x, 3, 1)) == typeof(view(x, :, :, 1)) + @testset "selectdim" begin + @test compare(x -> selectdim(x, 3, 1), AT, rand(Float32, 2, 2, 2)) + let x = AT(rand(Float32, 5, 4, 3)) + @test typeof(selectdim(x, 3, 1)) == typeof(view(x, :, :, 1)) + @test typeof(selectdim(x, 2, 1)) == typeof(view(x, :, 1, :)) + end end # bug in parentindices conversion