Use ChainRules for operators #22

@blegat

Currently, ReverseAD generates symbolic expressions for the first- and second-order derivatives of the classical univariate functions using Calculus:
https://github.com/jump-dev/MathOptInterface.jl/blob/100eab2e669e73689e1dc214391d97c24402e35c/src/Nonlinear/univariate_expressions_generator.jl
Then, given a representation of the operator as an Int, a hard-coded binary search selects the operator with O(log(n)) Int comparisons instead of O(n):
https://github.com/jump-dev/MathOptInterface.jl/blob/100eab2e669e73689e1dc214391d97c24402e35c/src/Nonlinear/operators.jl#L570-L582
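For illustration, the hard-coded binary search could look roughly like the sketch below. The operator ids and the function name `eval_univariate_by_id` are hypothetical; the actual table in MathOptInterface is generated and much longer.

```julia
# Hypothetical ids: 1 => sin, 2 => cos, 3 => tan, 4 => tanh.
# Nested comparisons halve the candidate range at each level,
# so 4 operators need at most 2 comparisons after the bounds check.
function eval_univariate_by_id(id::Int, x::Float64)
    1 <= id <= 4 || error("Unknown operator id $id")
    if id <= 2
        if id == 1
            return sin(x)
        else
            return cos(x)
        end
    else
        if id == 3
            return tan(x)
        else
            return tanh(x)
        end
    end
end
```

Every branch calls a function that is known at compile time, so the call is statically dispatched even though `id` is only known at run-time.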
I'm wondering whether we could move closer to ChainRules instead, like other Julia AD frameworks.
The naive way to do this would be

op = :tanh
f = eval(op)
value_and_derivative(f, 1)

The issue is that, because the value of op is only discovered at run-time, the type of f cannot be inferred, so the call is type-unstable.
But we can use the same trick with the if-else and do

if op == :tanh
    value_and_derivative(tanh, x)
elseif op == :tan
    value_and_derivative(tan, x)
elseif ...
else
    value_and_derivative(eval(op), x)
end

Again, we can do a binary search instead of just a list of if-else.
So, for a fixed set of symbols, the if-else avoids the type-instability, and the eval provides a fallback for all the other ones.
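A minimal runnable sketch of this idea follows. It assumes a hypothetical `value_and_derivative(f, x)` defined here with hand-written rules, standing in for whatever ChainRules-backed function would actually be used; `eval_operator` is also a made-up name.

```julia
# Hypothetical stand-ins for ChainRules-backed rules.
function value_and_derivative(::typeof(tanh), x::Real)
    y = tanh(x)
    return y, 1 - y^2
end
function value_and_derivative(::typeof(tan), x::Real)
    y = tan(x)
    return y, 1 + y^2
end
# A rule that is NOT in the if-else list below, to exercise the fallback.
value_and_derivative(::typeof(exp), x::Real) = (exp(x), exp(x))

function eval_operator(op::Symbol, x::Float64)
    if op == :tanh
        # `tanh` is a compile-time constant here: static dispatch.
        return value_and_derivative(tanh, x)
    elseif op == :tan
        return value_and_derivative(tan, x)
    else
        # Fallback: the function is only known at run-time, so this call
        # is dynamically dispatched. The type assertion is an added design
        # choice that keeps the return type of `eval_operator` inferable.
        return value_and_derivative(eval(op), x)::Tuple{Float64,Float64}
    end
end
```

Calling `eval_operator(:tanh, 0.0)` takes the first branch, while `eval_operator(:exp, 0.0)` goes through the `eval` fallback and still finds the rule by dispatch.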
That would also mean that, for registered functions, users would implement a method and rely on multiple dispatch instead of adding an operator to the list of user-defined operators. User-defined operators already trigger a type-instability when they are called anyway.
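As a sketch of what registering by dispatch might look like, assume again a hypothetical generic function `value_and_derivative` owned by the AD backend; with ChainRulesCore, the user would presumably add an `frule` or `rrule` method instead. The operator `my_square` is made up for the example.

```julia
# Hypothetical generic function owned by the AD backend.
function value_and_derivative end

# User-defined operator and its derivative rule, added as a method
# rather than as an entry in a list of user-defined operators.
my_square(x) = x^2
value_and_derivative(::typeof(my_square), x::Real) = (my_square(x), 2x)

# The `eval` fallback resolves the symbol to the function,
# then multiple dispatch finds the user's rule.
f = eval(:my_square)
result = value_and_derivative(f, 3.0)
```

Here `result` holds the value and the derivative of `my_square` at `3.0`, obtained without touching any operator registry.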
