Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Feb 14, 2025
1 parent ed9626f commit fad13e6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
10 changes: 5 additions & 5 deletions docs/src/submodules/Nonlinear/SymbolicAD.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ julia> f = MOI.ScalarNonlinearFunction(:sin, Any[x])
sin(MOI.VariableIndex(1))
julia> MOI.Nonlinear.SymbolicAD.derivative(f, x)
*(cos(MOI.VariableIndex(1)), (true))
cos(MOI.VariableIndex(1)
```

Note that the resultant expression can often be simplified. Thus, in most cases
Expand All @@ -196,14 +196,14 @@ using it in other places:
julia> x = MOI.VariableIndex(1)
MOI.VariableIndex(1)
julia> f = MOI.ScalarNonlinearFunction(:sin, Any[x])
sin(MOI.VariableIndex(1))
julia> f = MOI.ScalarNonlinearFunction(:sin, Any[x + 1.0])
sin(1.0 + 1.0 MOI.VariableIndex(1))
julia> df_dx = MOI.Nonlinear.SymbolicAD.derivative(f, x)
*(cos(MOI.VariableIndex(1)), (true))
*(cos(1.0 + 1.0 MOI.VariableIndex(1)), 1.0)
julia> MOI.Nonlinear.SymbolicAD.simplify!(df_dx)
cos(MOI.VariableIndex(1))
cos(1.0 + 1.0 MOI.VariableIndex(1))
```

## `gradient_and_hessian`
Expand Down
18 changes: 13 additions & 5 deletions src/Nonlinear/SymbolicAD/SymbolicAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,14 @@ that the user would never write themselves.
"""
const __DERIVATIVE__ = "__DERIVATIVE__"

# This function helps simplify df_du * du_dx in the commonn case that `du_dx`
# is `true` (when u = x), or `false` (when x ∉ u).
function _univariate_chain_rule(df_du, du_dx)
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
end

_univariate_chain_rule(df_du, du_dx::Bool) = ifelse(du_dx, df_du, du_dx)

function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
if length(f.args) == 1
u = only(f.args)
Expand All @@ -571,28 +579,28 @@ function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
:ifelse,
Any[MOI.ScalarNonlinearFunction(:>=, Any[u, 0]), 1, -1],
)
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
return _univariate_chain_rule(df_du, du_dx)
elseif f.head == :sign
return false
elseif f.head == :deg2rad
df_du = deg2rad(1)
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
return _univariate_chain_rule(df_du, du_dx)
elseif f.head == :rad2deg
df_du = rad2deg(1)
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
return _univariate_chain_rule(df_du, du_dx)
end
for (key, df, _) in MOI.Nonlinear.SYMBOLIC_UNIVARIATE_EXPRESSIONS
if key == f.head
# The chain rule: d(f(g(x))) / dx = f'(g(x)) * g'(x)
df_du = _replace_expression(copy(df), u)
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
return _univariate_chain_rule(df_du, du_dx)
end
end
# Delay derivative until evaluation. This may result in a later
# UnsupportedNonlinearOperator error, but we can't tell just yet.
d_op = Symbol(__DERIVATIVE__ * "$(f.head)")
df_du = MOI.ScalarNonlinearFunction(d_op, Any[u])
return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
return _univariate_chain_rule(df_du, du_dx)
end
if f.head == :+
# d/dx(+(args...)) = +(d/dx args)
Expand Down
6 changes: 6 additions & 0 deletions test/Nonlinear/SymbolicAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ function test_derivative()
return
end

function test_derivative_univariate_simplification()
x = MOI.VariableIndex(1)
@test SymbolicAD.derivative(op(:sin, x), x) op(:cos, x)
return
end

function test_derivative_error()
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(:foo, Any[x, x])
Expand Down

0 comments on commit fad13e6

Please sign in to comment.