Skip to content

Commit 5f07c95

Browse files
committed
Simplify FlattenedAlgorithm
1 parent f50737c commit 5f07c95

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/AlgorithmsInterfaceExtensions.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,15 +178,16 @@ end
178178
#============================ FlattenedAlgorithm ==========================================#
179179

180180
# Flatten a nested algorithm.
181-
function default_flattened_stopping_criterion(algorithm::NestedAlgorithm)
182-
return AI.StopAfterIteration(sum(max_iterations, algorithm.algorithms))
183-
end
184181
@kwdef struct FlattenedAlgorithm{
185-
ParentAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion,
182+
Algorithms <: AbstractVector{<:Algorithm},
183+
StoppingCriterion <: AI.StoppingCriterion,
186184
} <: Algorithm
187-
parent_algorithm::ParentAlgorithm
185+
algorithms::Algorithms
188186
stopping_criterion::StoppingCriterion =
189-
default_flattened_stopping_criterion(parent_algorithm)
187+
AI.StopAfterIteration(sum(max_iterations, algorithms))
188+
end
189+
function FlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
190+
return FlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
190191
end
191192

192193
@kwdef mutable struct FlattenedAlgorithmState{
@@ -212,7 +213,7 @@ function AI.increment!(
212213
)
213214
# Increment the total iteration count.
214215
state.iteration += 1
215-
if state.child_iteration max_iterations(algorithm.parent_algorithm.algorithms[state.parent_iteration])
216+
if state.child_iteration max_iterations(algorithm.algorithms[state.parent_iteration])
216217
# We're on the last iteration of the child algorithm, so move to the next
217218
# child algorithm.
218219
state.parent_iteration += 1
@@ -227,7 +228,7 @@ function AI.step!(
227228
problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState;
228229
logging_context_prefix = Symbol()
229230
)
230-
algorithm_sweep = algorithm.parent_algorithm.algorithms[state.parent_iteration]
231+
algorithm_sweep = algorithm.algorithms[state.parent_iteration]
231232
state_sweep = AI.initialize_state(
232233
problem, algorithm_sweep;
233234
state.iterate, iteration = state.child_iteration

test/test_basics.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,9 @@ using Test: @test, @testset
9292
end
9393
x = []
9494
problem = EigenProblem(operator)
95-
sweeping = AIE.NestedAlgorithm(nsweeps) do i
95+
algorithm = AIE.FlattenedAlgorithm(nsweeps) do i
9696
Sweep(; regions, region_kwargs = region_kwargs[i])
9797
end
98-
algorithm = AIE.FlattenedAlgorithm(; parent_algorithm = sweeping)
9998
state = AI.initialize_state(problem, algorithm; iterate = x)
10099
iterator = AIE.algorithm_iterator(problem, algorithm, state)
101100
iterations = Int[]

0 commit comments

Comments
 (0)