From c136a58a621131d4b376b69c5ec8cccfbaaa5a26 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 17:03:23 +0100 Subject: [PATCH] change blocktype of TensorMap to `StridedView` --- src/spaces/homspace.jl | 8 ++++---- src/tensors/tensor.jl | 23 +++++++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index d032a1586..017c6553f 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -299,7 +299,7 @@ const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int} struct FusionBlockStructure{I, N, F₁, F₂} totaldim::Int - blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} + blockstructure::SectorDict{I, StridedStructure{2}} fusiontreelist::Vector{Tuple{F₁, F₂}} fusiontreestructure::Vector{StridedStructure{N}} fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} @@ -325,9 +325,9 @@ end F₂ = fusiontreetype(I, N₂) # output structure - blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() # size, range + blockstructure = SectorDict{I, StridedStructure{2}}() # size, strides, offset fusiontreelist = Vector{Tuple{F₁, F₂}}() - fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂, Int}, NTuple{N₁ + N₂, Int}, Int}}() # size, strides, offset + fusiontreestructure = Vector{StridedStructure{N₁ + N₂}}() # size, strides, offset # temporary data structures splittingtrees = Vector{F₁}() @@ -367,8 +367,8 @@ end blocksize = (blockdim₁, blockdim₂) blocklength = blockdim₁ * blockdim₂ blockrange = (blockoffset + 1):(blockoffset + blocklength) + blockstructure[c] = (blocksize, strides, blockoffset) blockoffset = last(blockrange) - blockstructure[c] = (blocksize, blockrange) end fusiontreeindices = sizehint!( diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 300654632..a99e0e3c3 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -453,28 +453,35 @@ blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure) function blocktype(::Type{TT}) where {TT <: TensorMap} A = storagetype(TT) T = eltype(A) - return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}} + @static if isdefined(Core, :Memory) # StridedViews normalizes parent types! + if A <: Vector{T} + A = GenericMemory{T} + end + end + return StridedView{T, 2, A, typeof(identity)} end function Base.iterate(iter::BlockIterator{<:TensorMap}, state...) next = iterate(iter.structure, state...) isnothing(next) && return next - (c, (sz, r)), newstate = next - return c => reshape(view(iter.t.data, r), sz), newstate + (c, (sz, str, offset)), newstate = next + return c => StridedView(iter.t.data, sz, str, offset), newstate end function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector) sectortype(iter.t) === typeof(c) || throw(SectorMismatch()) - (d₁, d₂), r = get(iter.structure, c) do - # is s is not a key, at least one of the two dimensions will be zero: + (d₁, d₂), (s₁, s₂), offset = get(iter.structure, c) do + # is c is not a key, at least one of the two dimensions will be zero: # it then does not matter where exactly we construct a view in `t.data`, # as it will have length zero anyway d₁′ = blockdim(codomain(iter.t), c) d₂′ = blockdim(domain(iter.t), c) - l = d₁′ * d₂′ - return (d₁′, d₂′), 1:l + s₁ = 1 + s₂ = 0 + offset = 0 + return (d₁′, d₂′), (s₁, s₂), offset end - return reshape(view(iter.t.data, r), (d₁, d₂)) + return StridedView(iter.t.data, (d₁, d₂), (s₁, s₂), offset) end # Getting and setting the data at the subblock level