diff --git a/Project.toml b/Project.toml index 9c6da1b6..2ce216a7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.15.4" +version = "0.15.5" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index fc780376..7f650539 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -62,9 +62,9 @@ end function default_sweep_callback( sweep_iterator::SweepIterator{<:ApplyExpProblem}; exponent_description = "exponent", - outputlevel = 0, process_time = identity, ) + outputlevel = get(region_kwargs(region_iterator(sweep_iterator)), :outputlevel, 0) return if outputlevel >= 1 the_problem = problem(sweep_iterator) @printf( diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 8fbed9b2..9acaa513 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -23,8 +23,7 @@ end function update!( region_iter::RegionIterator{<:EigsolveProblem}, local_state; - outputlevel = 0, - solver = eigsolve_solver, + solver = eigsolve_solver ) prob = problem(region_iter) @@ -34,6 +33,7 @@ function update!( prob.eigenvalue = eigval + outputlevel = get(region_kwargs(region_iter), :outputlevel, 0) if outputlevel >= 2 @printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob)) end @@ -41,8 +41,9 @@ function update!( end function default_sweep_callback( - sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel = 0 + sweep_iterator::SweepIterator{<:EigsolveProblem} ) + outputlevel = get(region_kwargs(region_iterator(sweep_iterator)), :outputlevel, 0) return if outputlevel >= 1 nsweeps = length(sweep_iterator) current_sweep = sweep_iterator.which_sweep @@ -51,9 +52,10 @@ function default_sweep_callback( else @printf("After sweep %d/%d ", current_sweep, nsweeps) end - @printf("eigenvalue=%.12f", eigenvalue(problem)) - @printf(" maxlinkdim=%d", maxlinkdim(state(problem))) - @printf(" max truncerror=%d", max_truncerror(problem)) + current_problem = problem(sweep_iterator) + @printf("eigenvalue=%.12f", eigenvalue(current_problem)) + @printf(" maxlinkdim=%d", maxlinkdim(current_problem)) + @printf(" max truncerror=%d", max_truncerror(current_problem)) println() flush(stdout) end diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 990d86d2..b5c5c0bb 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -12,7 +12,10 @@ abstract type AbstractNetworkIterator end islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) function Base.iterate(iterator::AbstractNetworkIterator, init = true) - islaststep(iterator) && return nothing + # The assumption is that first "increment!" is implicit, therefore we must skip the + # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not + # defined when length < 1, + init || islaststep(iterator) && return nothing # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* # define a method for increment! This way we avoid cases where one may wish to nest # calls to different step! methods accidentaly incrementing multiple times. @@ -44,6 +47,9 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator which_region::Int const which_sweep::Int function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} + if length(region_plan) == 0 + throw(BoundsError("Cannot construct a region iterator with 0 elements.")) + end return new{P, R}(problem, region_plan, 1, sweep) end end @@ -115,26 +121,33 @@ region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator region_iter::RegionIterator{Problem} - sweep_kwargs::Iterators.Stateful{Iter} + sweep_kwargs::Iter which_sweep::Int + nsweeps::Int function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} - stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs) + first_state = Iterators.peel(sweep_kwargs) + if isnothing(first_state) + throw(BoundsError("Cannot construct a sweep iterator with 0 elements.")) + end + first_kwargs, sweep_kwargs_rest = first_state region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) - return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) + return new{Prob, typeof(sweep_kwargs_rest)}(region_iter, sweep_kwargs_rest, 1, length(sweep_kwargs)) end end -islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) +islaststep(sweep_iter::SweepIterator) = isempty(sweep_iter.sweep_kwargs) region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter + problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) state(sweep_iter::SweepIterator) = sweep_iter.which_sweep -Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) + +Base.length(sweep_iter::SweepIterator) = sweep_iter.nsweeps + function increment!(sweep_iter::SweepIterator) sweep_iter.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) + sweep_kwargs, sweep_iter.sweep_kwargs = Iterators.peel(sweep_iter.sweep_kwargs) update_region_iterator!(sweep_iter; sweep_kwargs...) return sweep_iter end diff --git a/src/solvers/region_plans/euler_plans.jl b/src/solvers/region_plans/euler_plans.jl index aff6e3cd..c631ecaf 100644 --- a/src/solvers/region_plans/euler_plans.jl +++ b/src/solvers/region_plans/euler_plans.jl @@ -1,7 +1,7 @@ using Graphs: dst, src using NamedGraphs.GraphsExtensions: default_root_vertex -function euler_sweep(graph; nsites, root_vertex = default_root_vertex(graph), sweep_kwargs...) +function euler_sweep(graph; nsites = 1, root_vertex = default_root_vertex(graph), sweep_kwargs...) sweep_kwargs = (; nsites, root_vertex, sweep_kwargs...) if nsites == 1 diff --git a/test/solvers/test_applyexp.jl b/test/solvers/test_applyexp.jl index 79b37003..4c7e0938 100644 --- a/test/solvers/test_applyexp.jl +++ b/test/solvers/test_applyexp.jl @@ -53,7 +53,7 @@ end nsites = 2 factorize_kwargs = (; cutoff, maxdim) - E, gs_psi = dmrg(H, psi0; factorize_kwargs, nsites, nsweeps, outputlevel) + E, gs_psi = dmrg(H, psi0; factorize_kwargs, nsites, nsweeps, outputlevel = 0) (outputlevel >= 1) && println("2-site DMRG energy = ", E) nsites = 1 diff --git a/test/solvers/test_sweepiterator.jl b/test/solvers/test_sweepiterator.jl new file mode 100644 index 00000000..69340f41 --- /dev/null +++ b/test/solvers/test_sweepiterator.jl @@ -0,0 +1,60 @@ +using Test: @test, @testset +using ITensorNetworks: ITensorNetworks, AbstractProblem, RegionIterator, SweepIterator, compute!, region_iterator, region_kwargs + +include("utilities/tree_graphs.jl") + +# TestProblem type for testing +struct TestProblem <: AbstractProblem + graph +end + +ITensorNetworks.state(T::TestProblem) = T.graph + +ITensorNetworks.compute!(R::RegionIterator{<:TestProblem}) = "TestProblem Compute" + + +@testset "SweepIterator Basics" begin + g = build_tree(; nbranch = 3, nbranch_sites = 3) + prob = TestProblem(g) + + nsweeps = 5 + + # Basic construction, taking length + sweep_iter = SweepIterator(prob, nsweeps) + @test length(sweep_iter) == nsweeps + + # Pass keyword parameters + test_kwarg_a = 1 + test_kwarg_b = "b" + sweep_iter = SweepIterator(prob, nsweeps; test_kwarg_a, test_kwarg_b) + @test region_kwargs(region_iterator(sweep_iter)).test_kwarg_a == test_kwarg_a + @test region_kwargs(region_iterator(sweep_iter)).test_kwarg_b == test_kwarg_b + + # Pass array of parameters + kws_array = [(; outputlevel = 0), (; outputlevel = 1)] + sweep_iter = SweepIterator(prob, kws_array) + @test length(sweep_iter) == length(kws_array) + @test region_kwargs(region_iterator(sweep_iter)).outputlevel == 0 +end + +@testset "SweepIterator Iteration" begin + g = build_tree(; nbranch = 3, nbranch_sites = 3) + prob = TestProblem(g) + + nsweeps = 5 + sweep_iter = SweepIterator(prob, nsweeps) + count = 0 + for _ in sweep_iter + count += 1 + end + @test count == nsweeps + + # Test case of one iteration + nsweeps = 1 + sweep_iter = SweepIterator(prob, nsweeps) + count = 0 + for _ in sweep_iter + count += 1 + end + @test count == nsweeps +end