Skip to content

Commit 31dbad5

Browse files
committed
Refactor types a bit
1 parent 178ec48 commit 31dbad5

File tree

4 files changed

+71
-49
lines changed

4 files changed

+71
-49
lines changed

src/AlgorithmsInterfaceExtensions.jl

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,6 @@ end
104104

105105
abstract type AlgorithmIterator end
106106

107-
struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
108-
problem::Problem
109-
algorithm::Algorithm
110-
state::State
111-
end
112-
113107
function algorithm_iterator(
114108
problem::Problem, algorithm::Algorithm, state::State
115109
)
@@ -135,6 +129,12 @@ function Base.iterate(iterator::AlgorithmIterator, init = nothing)
135129
return iterator.state, nothing
136130
end
137131

132+
struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
133+
problem::Problem
134+
algorithm::Algorithm
135+
state::State
136+
end
137+
138138
#============================ with_algorithmlogger ========================================#
139139

140140
# Allow passing functions, not just CallbackActions.
@@ -147,12 +147,16 @@ end
147147

148148
#============================ NestedAlgorithm =============================================#
149149

150-
abstract type AbstractNestedAlgorithm <: Algorithm end
150+
abstract type NestedAlgorithm <: Algorithm end
151+
152+
function nested_algorithm(f::Function, nalgorithms::Int; kwargs...)
153+
return DefaultNestedAlgorithm(f, nalgorithms; kwargs...)
154+
end
151155

152-
max_iterations(algorithm::AbstractNestedAlgorithm) = length(algorithm.algorithms)
156+
max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms)
153157

154158
function AI.step!(
155-
problem::AI.Problem, algorithm::AbstractNestedAlgorithm, state::AI.State;
159+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State;
156160
logging_context_prefix = Symbol()
157161
)
158162
# Perform the current sweep.
@@ -167,45 +171,30 @@ function AI.step!(
167171
end
168172

169173
#=
170-
NestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
174+
DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
171175
172176
An algorithm that consists of running an algorithm at each iteration
173177
from a list of stored algorithms.
174178
=#
175-
@kwdef struct NestedAlgorithm{
179+
@kwdef struct DefaultNestedAlgorithm{
176180
Algorithms <: AbstractVector{<:Algorithm},
177181
StoppingCriterion <: AI.StoppingCriterion,
178-
} <: AbstractNestedAlgorithm
182+
} <: NestedAlgorithm
179183
algorithms::Algorithms
180184
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
181185
end
182-
function NestedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
183-
return NestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
186+
function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
187+
return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
184188
end
185189

186190
#============================ FlattenedAlgorithm ==========================================#
187191

188192
# Flatten a nested algorithm.
189-
@kwdef struct FlattenedAlgorithm{
190-
Algorithms <: AbstractVector{<:Algorithm},
191-
StoppingCriterion <: AI.StoppingCriterion,
192-
} <: Algorithm
193-
algorithms::Algorithms
194-
stopping_criterion::StoppingCriterion =
195-
AI.StopAfterIteration(sum(max_iterations, algorithms))
196-
end
197-
function FlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
198-
return FlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
199-
end
193+
abstract type FlattenedAlgorithm <: Algorithm end
194+
abstract type FlattenedAlgorithmState <: State end
200195

201-
@kwdef mutable struct FlattenedAlgorithmState{
202-
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
203-
} <: State
204-
iterate::Iterate
205-
iteration::Int = 0
206-
parent_iteration::Int = 1
207-
child_iteration::Int = 0
208-
stopping_criterion_state::StoppingCriterionState
196+
function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...)
197+
return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...)
209198
end
210199

211200
function AI.initialize_state(
@@ -214,7 +203,7 @@ function AI.initialize_state(
214203
stopping_criterion_state = AI.initialize_state(
215204
problem, algorithm, algorithm.stopping_criterion
216205
)
217-
return FlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
206+
return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
218207
end
219208
function AI.increment!(
220209
problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState
@@ -247,4 +236,26 @@ function AI.step!(
247236
return state
248237
end
249238

239+
@kwdef struct DefaultFlattenedAlgorithm{
240+
Algorithms <: AbstractVector{<:Algorithm},
241+
StoppingCriterion <: AI.StoppingCriterion,
242+
} <: FlattenedAlgorithm
243+
algorithms::Algorithms
244+
stopping_criterion::StoppingCriterion =
245+
AI.StopAfterIteration(sum(max_iterations, algorithms))
246+
end
247+
function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
248+
return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
249+
end
250+
251+
@kwdef mutable struct DefaultFlattenedAlgorithmState{
252+
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
253+
} <: FlattenedAlgorithmState
254+
iterate::Iterate
255+
iteration::Int = 0
256+
parent_iteration::Int = 1
257+
child_iteration::Int = 0
258+
stopping_criterion_state::StoppingCriterionState
259+
end
260+
250261
end

src/eigenproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99

1010
function dmrg(operator, state; nsweeps, regions, region_kwargs, kwargs...)
1111
problem = EigenProblem(operator)
12-
algorithm = AIE.NestedAlgorithm(nsweeps) do i
12+
algorithm = Sweeping(nsweeps) do i
1313
return Sweep(; regions, region_kwargs = region_kwargs[i])
1414
end
1515
return AI.solve(problem, algorithm; iterate = state, kwargs...).iterate
@@ -35,7 +35,7 @@ end
3535
function solve_region!!(problem::EigenProblem, algorithm::RegionAlgorithm, state)
3636
operator = problem.operator
3737
region = algorithm.region
38-
region_kwargs = algorithm.region_kwargs(algorithm, state)
38+
region_kwargs = algorithm.kwargs(algorithm, state)
3939

4040
#=
4141
# Reduce the `operator` and state `x` onto the region `region`,

src/sweep.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import AlgorithmsInterface as AI
22
import .AlgorithmsInterfaceExtensions as AIE
33

4-
@kwdef struct RegionAlgorithm{Region, RegionKwargs <: Function} <: AIE.Algorithm
5-
region::Region
6-
region_kwargs::RegionKwargs
4+
@kwdef struct Sweeping{
5+
Algorithms <: AbstractVector{<:AI.Algorithm},
6+
StoppingCriterion <: AI.StoppingCriterion,
7+
} <: AIE.NestedAlgorithm
8+
algorithms::Algorithms
9+
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
710
end
8-
function RegionAlgorithm(region, region_kwargs::NamedTuple)
9-
return RegionAlgorithm(region, Returns(region_kwargs))
11+
function Sweeping(f::Function, nalgorithms::Int; kwargs...)
12+
return Sweeping(; algorithms = f.(1:nalgorithms), kwargs...)
1013
end
1114

1215
#=
@@ -36,3 +39,11 @@ function Sweep(;
3639
return Sweep(algorithms, stopping_criterion)
3740
end
3841
AIE.max_iterations(algorithm::Sweep) = length(algorithm.algorithms)
42+
43+
@kwdef struct RegionAlgorithm{Region, Kwargs <: Function}
44+
region::Region
45+
kwargs::Kwargs
46+
end
47+
function RegionAlgorithm(region, kwargs::NamedTuple)
48+
return RegionAlgorithm(region, Returns(kwargs))
49+
end

test/test_basics.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import AlgorithmsInterface as AI
22
using Graphs: path_graph
3-
using TensorNetworkSolvers: EigenProblem, Sweep, dmrg, dmrg_sweep
3+
using TensorNetworkSolvers: EigenProblem, Sweep, Sweeping, dmrg, dmrg_sweep
44
import TensorNetworkSolvers.AlgorithmsInterfaceExtensions as AIE
55
using Test: @test, @testset
66

@@ -66,7 +66,7 @@ using Test: @test, @testset
6666
end
6767
x = []
6868
problem = EigenProblem(operator)
69-
algorithm = AIE.NestedAlgorithm(nsweeps) do i
69+
algorithm = Sweeping(nsweeps) do i
7070
Sweep(; regions, region_kwargs = region_kwargs[i])
7171
end
7272
state = AI.initialize_state(problem, algorithm; iterate = x)
@@ -92,7 +92,7 @@ using Test: @test, @testset
9292
end
9393
x = []
9494
problem = EigenProblem(operator)
95-
algorithm = AIE.FlattenedAlgorithm(nsweeps) do i
95+
algorithm = AIE.flattened_algorithm(nsweeps) do i
9696
Sweep(; regions, region_kwargs = region_kwargs[i])
9797
end
9898
state = AI.initialize_state(problem, algorithm; iterate = x)
@@ -146,11 +146,11 @@ using Test: @test, @testset
146146
return nothing
147147
end
148148
x = AIE.with_algorithmlogger(
149-
:EigenProblem_NestedAlgorithm_Start => print_dmrg_start,
150-
:EigenProblem_NestedAlgorithm_PreStep => print_dmrg_prestep,
151-
:EigenProblem_NestedAlgorithm_PostStep => print_dmrg_poststep,
152-
:EigenProblem_NestedAlgorithm_Sweep_Start => print_sweep_start,
153-
:EigenProblem_NestedAlgorithm_Sweep_PostStep => print_sweep_poststep,
149+
:EigenProblem_Sweeping_Start => print_dmrg_start,
150+
:EigenProblem_Sweeping_PreStep => print_dmrg_prestep,
151+
:EigenProblem_Sweeping_PostStep => print_dmrg_poststep,
152+
:EigenProblem_Sweeping_Sweep_Start => print_sweep_start,
153+
:EigenProblem_Sweeping_Sweep_PostStep => print_sweep_poststep,
154154
) do
155155
x = dmrg(operator, x0; nsweeps, regions, region_kwargs)
156156
return x

0 commit comments

Comments
 (0)