Description
Currently, ReverseAD uses Calculus to generate symbolic expressions for the first- and second-order derivatives of classical univariate functions:
https://github.com/jump-dev/MathOptInterface.jl/blob/100eab2e669e73689e1dc214391d97c24402e35c/src/Nonlinear/univariate_expressions_generator.jl
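To illustrate the general idea of turning symbolic derivative expressions into Julia methods, here is a minimal sketch. It does not use Calculus: the derivative expressions are written by hand as a stand-in for what Calculus would derive, and the generated names `d_sin`, `d_exp`, `d_tanh` and the table `DERIVATIVE_RULES` are hypothetical. The linked generator's exact mechanics may differ.

```julia
# Hand-written (operator, derivative expression in `x`) pairs.
# Hypothetical stand-in for the expressions Calculus would derive symbolically.
const DERIVATIVE_RULES = [
    (:sin, :(cos(x))),
    (:exp, :(exp(x))),
    (:tanh, :(1 - tanh(x)^2)),
]

# Generate one method `d_<op>(x)` per rule at load time.
for (op, dexpr) in DERIVATIVE_RULES
    fname = Symbol(:d_, op)      # e.g. :d_sin
    @eval $fname(x) = $dexpr     # defines e.g. d_sin(x) = cos(x)
end
```

For example, `d_tanh(0.0)` evaluates `1 - tanh(0.0)^2`.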
Then, given a representation of the operator as an `Int`, we do a hard-coded binary search, which needs only O(log(n)) `Int` comparisons instead of O(n) comparisons:
https://github.com/jump-dev/MathOptInterface.jl/blob/100eab2e669e73689e1dc214391d97c24402e35c/src/Nonlinear/operators.jl#L570-L582
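A minimal sketch of this kind of hard-coded binary search, assuming four hypothetical operator ids (1 => `sin`, 2 => `cos`, 3 => `exp`, 4 => `tanh`); MOI's actual numbering and operator list differ. Each branch calls a concrete function, so the return type is inferred as `Tuple{Float64,Float64}`.

```julia
# Hypothetical ids: 1 => sin, 2 => cos, 3 => exp, 4 => tanh.
# Returns (value, first derivative). With n operators, nested halving
# needs about log2(n) comparisons instead of up to n for a linear chain.
function eval_univariate_id(id::Int, x::Float64)
    if id <= 2                        # first comparison halves the range
        if id == 1
            return (sin(x), cos(x))   # sin
        elseif id == 2
            return (cos(x), -sin(x))  # cos
        end
    else
        if id == 3
            y = exp(x)
            return (y, y)             # exp
        elseif id == 4
            t = tanh(x)
            return (t, 1 - t^2)       # tanh
        end
    end
    error("unknown operator id: $id")
end
```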
I'm wondering whether we could instead rely on ChainRules, like other Julia AD frameworks do.
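The naive way to do this would be:

```julia
op = :tanh
f = eval(op)                # the concrete type of `f` is only known at run time
value_and_derivative(f, 1)  # placeholder for a function returning (f(x), f'(x))
```

The issue is that, because the value of `op` is only discovered at run time, the type of `f` cannot be inferred, so the call is type-unstable.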
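One way to observe this instability is with `Base.return_types` (or `@code_warntype`). In this sketch, the hypothetical `naive` looks the function up by name at run time, while `direct` calls `tanh` directly:

```julia
# Function looked up by name at run time: inference cannot know what `eval` returns.
naive(op::Symbol, x::Float64) = eval(op)(x)

# Function known at compile time: the return type is inferred.
direct(x::Float64) = tanh(x)

Base.return_types(naive, (Symbol, Float64))  # [Any]
Base.return_types(direct, (Float64,))        # [Float64]
```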
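However, we can use the same if-else trick:

```julia
if op == :tanh
    value_and_derivative(tanh, x)      # `tanh` is known at compile time: type-stable
elseif op == :tan
    value_and_derivative(tan, x)
# elseif ...                           # one branch per supported operator
else
    value_and_derivative(eval(op), x)  # fallback for any other symbol (type-unstable)
end
```

Again, we can do a binary search instead of a plain chain of if-else branches.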
So, for a fixed set of symbols, the if-else avoids the type instability, and `eval` provides a fallback for all other symbols.
This would also mean that, for registered functions, we would only need to implement a method and rely on multiple dispatch, instead of adding an operator to the list of user-defined operators. User-defined operators already trigger a type instability when they are called anyway.
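A runnable sketch of the proposal, under stated assumptions: `value_and_derivative` is a hypothetical stand-in for ChainRules-style rules (one method per function, returning value and derivative), `eval_univariate` is a hypothetical dispatcher, and `square` plays the role of a registered user function. Known symbols hit a branch where the function is a compile-time constant; `square` is reached only through the `eval` fallback and needs nothing more than a new method.

```julia
# Hypothetical stand-in for ChainRules-style rules: one method per function,
# each returning (value, derivative).
value_and_derivative(::typeof(tanh), x) = (t = tanh(x); (t, 1 - t^2))
value_and_derivative(::typeof(tan), x) = (t = tan(x); (t, 1 + t^2))

# Known symbols: the callee is a constant, so these branches are type-stable.
# Anything else goes through `eval`, which is type-unstable but still works.
function eval_univariate(op::Symbol, x::Float64)
    if op == :tanh
        return value_and_derivative(tanh, x)
    elseif op == :tan
        return value_and_derivative(tan, x)
    else
        return value_and_derivative(eval(op), x)
    end
end

# A "registered" user function only needs a new method, no operator-table entry.
square(x) = x^2
value_and_derivative(::typeof(square), x) = (x^2, 2x)
```

For example, `eval_univariate(:square, 3.0)` dispatches to the user-provided method through the fallback branch.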