Skip to content

Commit 8804c87

Browse files
committed
Custom rule for initialize_output
1 parent 045b79d commit 8804c87

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2626
return CoDual(Ac, dAc), copy_input_pb
2727
end
2828

29+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any}
30+
function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual)
31+
output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg))
32+
doutput = Mooncake.zero_tangent(output)
33+
function initialize_output_pb(::NoRData)
34+
return NoRData(), NoRData(), NoRData(), NoRData()
35+
end
36+
return CoDual(output, doutput), initialize_output_pb
37+
end
38+
39+
2940
# two-argument in-place factorizations like LQ, QR, EIG
3041
for (f!, f, pb, adj) in (
3142
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),

0 commit comments

Comments
 (0)