Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.3.1"
version = "0.3.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand All @@ -10,9 +10,9 @@ AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d"
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -37,13 +37,13 @@ AlgorithmsInterface = "0.1"
BackendSelection = "0.1.6"
Combinatorics = "1"
DataGraphs = "0.2.7"
DerivableInterfaces = "0.5.5"
DiagonalArrays = "0.3.23"
Dictionaries = "0.4.5"
FunctionImplementations = "0.3"
Graphs = "1.13.1"
LinearAlgebra = "1.10"
MacroTools = "0.5.16"
NamedDimsArrays = "0.8, 0.9, 0.10, 0.11"
NamedDimsArrays = "0.12"
NamedGraphs = "0.6.9, 0.7, 0.8"
SimpleTraits = "0.9.5"
SplitApplyCombine = "1.2.3"
Expand Down
2 changes: 1 addition & 1 deletion src/LazyNamedDimsArrays/lazybroadcast.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NamedDimsArrays: AbstractNamedDimsArrayStyle
using NamedDimsArrays.Broadcast: AbstractNamedDimsArrayStyle

# Lazy broadcasting.
struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end
Expand Down
63 changes: 30 additions & 33 deletions src/LazyNamedDimsArrays/lazyinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,19 @@ opwalk(opmap, a) = walk(opmap, identity, a)
argwalk(argmap, a) = walk(identity, argmap, a)

# Generic lazy functionality.
using DerivableInterfaces: AbstractArrayInterface, InterfaceFunction
struct LazyInterface{N} <: AbstractArrayInterface{N} end
LazyInterface() = LazyInterface{Any}()
LazyInterface(::Val{N}) where {N} = LazyInterface{N}()
LazyInterface{M}(::Val{N}) where {M, N} = LazyInterface{N}()
const lazy_interface = LazyInterface()
using FunctionImplementations: AbstractArrayStyle
struct LazyStyle <: AbstractArrayStyle end
const lazy_style = LazyStyle()

const maketerm_lazy = lazy_interface(maketerm)
const maketerm_lazy = lazy_style(maketerm)
function maketerm_lazy(type::Type, head, args, metadata)
if head ≡ *
return type(maketerm(Mul, head, args, metadata))
else
return error("Only mul supported right now.")
end
end
const getindex_lazy = lazy_interface(getindex)
const getindex_lazy = lazy_style(getindex)
function getindex_lazy(a::AbstractArray, I...)
u = unwrap(a)
if !iscall(u)
Expand All @@ -47,7 +44,7 @@ function getindex_lazy(a::AbstractArray, I...)
return error("Indexing into expression not supported.")
end
end
const arguments_lazy = lazy_interface(arguments)
const arguments_lazy = lazy_style(arguments)
function arguments_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -59,17 +56,17 @@ function arguments_lazy(a)
end
end
using TermInterface: children
const children_lazy = lazy_interface(children)
const children_lazy = lazy_style(children)
children_lazy(a) = arguments(a)
using TermInterface: head
const head_lazy = lazy_interface(head)
const head_lazy = lazy_style(head)
head_lazy(a) = operation(a)
const iscall_lazy = lazy_interface(iscall)
const iscall_lazy = lazy_style(iscall)
iscall_lazy(a) = iscall(unwrap(a))
using TermInterface: isexpr
const isexpr_lazy = lazy_interface(isexpr)
const isexpr_lazy = lazy_style(isexpr)
isexpr_lazy(a) = iscall(a)
const operation_lazy = lazy_interface(operation)
const operation_lazy = lazy_style(operation)
function operation_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -80,7 +77,7 @@ function operation_lazy(a)
return error("Variant not supported.")
end
end
const sorted_arguments_lazy = lazy_interface(sorted_arguments)
const sorted_arguments_lazy = lazy_style(sorted_arguments)
function sorted_arguments_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -92,12 +89,12 @@ function sorted_arguments_lazy(a)
end
end
using TermInterface: sorted_children
const sorted_children_lazy = lazy_interface(sorted_children)
const sorted_children_lazy = lazy_style(sorted_children)
sorted_children_lazy(a) = sorted_arguments(a)
const ismul_lazy = lazy_interface(ismul)
const ismul_lazy = lazy_style(ismul)
ismul_lazy(a) = ismul(unwrap(a))
using AbstractTrees: AbstractTrees
const abstracttrees_children_lazy = lazy_interface(AbstractTrees.children)
const abstracttrees_children_lazy = lazy_style(AbstractTrees.children)
function abstracttrees_children_lazy(a)
if !iscall(a)
return ()
Expand All @@ -106,7 +103,7 @@ function abstracttrees_children_lazy(a)
end
end
using AbstractTrees: nodevalue
const nodevalue_lazy = lazy_interface(nodevalue)
const nodevalue_lazy = lazy_style(nodevalue)
function nodevalue_lazy(a)
if !iscall(a)
return unwrap(a)
Expand All @@ -115,11 +112,11 @@ function nodevalue_lazy(a)
end
end
using Base.Broadcast: materialize
const materialize_lazy = lazy_interface(materialize)
const materialize_lazy = lazy_style(materialize)
materialize_lazy(a) = argwalk(unwrap, a)
const copy_lazy = lazy_interface(copy)
const copy_lazy = lazy_style(copy)
copy_lazy(a) = materialize(a)
const equals_lazy = lazy_interface(==)
const equals_lazy = lazy_style(==)
function equals_lazy(a1, a2)
u1, u2 = unwrap.((a1, a2))
if !iscall(u1) && !iscall(u2)
Expand All @@ -130,7 +127,7 @@ function equals_lazy(a1, a2)
return false
end
end
const isequal_lazy = lazy_interface(isequal)
const isequal_lazy = lazy_style(isequal)
function isequal_lazy(a1, a2)
u1, u2 = unwrap.((a1, a2))
if !iscall(u1) && !iscall(u2)
Expand All @@ -141,13 +138,13 @@ function isequal_lazy(a1, a2)
return false
end
end
const hash_lazy = lazy_interface(hash)
const hash_lazy = lazy_style(hash)
function hash_lazy(a, h::UInt64)
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
# Use `_hash`, which defines a custom hash for NamedDimsArray.
return _hash(unwrap(a), h)
end
const map_arguments_lazy = lazy_interface(map_arguments)
const map_arguments_lazy = lazy_style(map_arguments)
function map_arguments_lazy(f, a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -159,21 +156,21 @@ function map_arguments_lazy(f, a)
end
end
function substitute end
const substitute_lazy = lazy_interface(substitute)
const substitute_lazy = lazy_style(substitute)
function substitute_lazy(a, substitutions::AbstractDict)
haskey(substitutions, a) && return substitutions[a]
!iscall(a) && return a
return map_arguments(arg -> substitute(arg, substitutions), a)
end
substitute_lazy(a, substitutions) = substitute(a, Dict(substitutions))
using AbstractTrees: printnode
const printnode_lazy = lazy_interface(printnode)
const printnode_lazy = lazy_style(printnode)
function printnode_lazy(io, a)
# Use `printnode_nameddims` to avoid type piracy,
# since it overloads on `AbstractNamedDimsArray`.
return printnode_nameddims(io, unwrap(a))
end
const show_lazy = lazy_interface(show)
const show_lazy = lazy_style(show)
function show_lazy(io::IO, a)
if !iscall(a)
return show(io, unwrap(a))
Expand All @@ -187,12 +184,12 @@ function show_lazy(io::IO, mime::MIME"text/plain", a)
!iscall(a) ? show(io, mime, unwrap(a)) : show(io, a)
return nothing
end
const add_lazy = lazy_interface(+)
const add_lazy = lazy_style(+)
add_lazy(a1, a2) = error("Not implemented.")
const sub_lazy = lazy_interface(-)
const sub_lazy = lazy_style(-)
sub_lazy(a) = error("Not implemented.")
sub_lazy(a1, a2) = error("Not implemented.")
const mul_lazy = lazy_interface(*)
const mul_lazy = lazy_style(*)
function mul_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -216,7 +213,7 @@ mul_lazy(a1::Number, a2::Number) = a1 * a2
div_lazy(a1, a2::Number) = error("Not implemented.")

# NamedDimsArrays.jl interface.
const inds_lazy = lazy_interface(inds)
const inds_lazy = lazy_style(inds)
function inds_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand All @@ -227,7 +224,7 @@ function inds_lazy(a)
return error("Variant not supported.")
end
end
const dename_lazy = lazy_interface(dename)
const dename_lazy = lazy_style(dename)
function dename_lazy(a)
u = unwrap(a)
if !iscall(u)
Expand Down
4 changes: 2 additions & 2 deletions src/LazyNamedDimsArrays/symbolicarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ end
function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N}
return error("Indexing into SymbolicArray not supported.")
end
using DerivableInterfaces: DerivableInterfaces
DerivableInterfaces.permuteddims(a::SymbolicArray, p) = permutedims(a, p)
using FunctionImplementations: FunctionImplementations
FunctionImplementations.permuteddims(a::SymbolicArray, p) = permutedims(a, p)
function Base.permutedims(a::SymbolicArray, p)
@assert ndims(a) == length(p) && isperm(p)
return SymbolicArray(symname(a), ntuple(i -> axes(a)[p[i]], ndims(a)))
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Dictionaries = "0.4.5"
Graphs = "1.13.1"
ITensorBase = "0.3, 0.4"
ITensorNetworksNext = "0.3"
NamedDimsArrays = "0.8, 0.9, 0.10, 0.11"
NamedDimsArrays = "0.12"
NamedGraphs = "0.6.8, 0.7, 0.8"
QuadGK = "2.11.2"
SafeTestsets = "0.1"
Expand Down
Loading