Skip to content
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
name = "ArrayDiff"
uuid = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
authors = ["Benoît Legat <benoit.legat@gmail.com>"]
version = "0.1.0"
authors = ["Benoît Legat <benoit.legat@gmail.com>"]

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[compat]
DataStructures = "0.18, 0.19"
ForwardDiff = "1"
MathOptInterface = "1.40"
NaNMath = "1"
SparseArrays = "1.10"
SparseMatrixColorings = "0.4"
julia = "1.10"
24 changes: 18 additions & 6 deletions src/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,38 @@

module ArrayDiff

import SparseArrays
import SparseMatrixColorings as SMC
import ForwardDiff
import MathOptInterface as MOI
const Nonlinear = MOI.Nonlinear
import SparseArrays

"""
Mode() <: AbstractAutomaticDifferentiation
Mode(coloring_algorithm::SMC.GreedyColoringAlgorithm) <: AbstractAutomaticDifferentiation

Fork of `MOI.Nonlinear.SparseReverseMode` to add array support.
"""
struct Mode <: MOI.Nonlinear.AbstractAutomaticDifferentiation end
struct Mode{C<:SMC.GreedyColoringAlgorithm} <:
MOI.Nonlinear.AbstractAutomaticDifferentiation
coloring_algorithm::C
end

function Mode()
return Mode(
SMC.GreedyColoringAlgorithm(;
decompression = :substitution,
),
)
end

function MOI.Nonlinear.Evaluator(
model::MOI.Nonlinear.Model,
::Mode,
mode::Mode,
ordered_variables::Vector{MOI.VariableIndex},
)
return MOI.Nonlinear.Evaluator(
model,
NLPEvaluator(model, ordered_variables),
NLPEvaluator(model, ordered_variables, mode.coloring_algorithm),
)
end

Expand All @@ -48,7 +60,7 @@ import NaNMath:
pow,
sqrt

include("Coloring/Coloring.jl")
include("coloring.jl")
include("graph_tools.jl")
include("sizes.jl")
include("types.jl")
Expand Down
Loading
Loading