@@ -55,20 +55,27 @@ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata
5555
5656 dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
5757
58- copy_out, copy_pb!! = isnothing (alg) ? rrule (Mooncake . CoDual (f_c, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_copy), Mooncake . CoDual (args, dargs_copy)) : rrule (Mooncake . CoDual (f_c, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_copy), Mooncake . CoDual (args, dargs_copy), Mooncake . CoDual (alg, Mooncake . NoFData ()) )
59- copy_pb!! (rdata )
60-
61- # not every f! has a handwritten rrule!!
62- inplace_sig = isnothing (alg) ? Tuple{ typeof (f!), typeof (A), typeof (args)} : Tuple{ typeof (f!), typeof (A), typeof (args ), typeof (alg)}
63-
64- has_handwritten_rule = hasmethod ( Mooncake. rrule!! , inplace_sig)
65- if has_handwritten_rule
66- inplace_out, inplace_pb!! = isnothing (alg) ? Mooncake . rrule!! (Mooncake . CoDual (f!, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_inplace), Mooncake . CoDual (args, dargs_inplace)) : Mooncake . rrule!! (Mooncake . CoDual (f!, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_inplace), Mooncake . CoDual (args, dargs_inplace), Mooncake . CoDual (alg, Mooncake . NoFData ()))
58+ if isnothing (alg)
59+ copy_out, copy_pb!! = rrule (Mooncake . CoDual (f_c, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_copy), Mooncake . CoDual (args, dargs_copy) )
60+ inplace_sig = Tuple{ typeof (f!), typeof (A), typeof (args)})
61+ if hasmethod (Mooncake . rrule!!, inplace_sig) # hand written rule
62+ inplace_out, inplace_pb!! = Mooncake . rrule!! (Mooncake . CoDual (f!, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_inplace ), Mooncake . CoDual (args, dargs_inplace))
63+ else
64+ inplace_rrule = Mooncake. build_rrule (rvs_interp , inplace_sig)
65+ inplace_out, inplace_pb!! = inplace_rrule (Mooncake . CoDual (f!, Mooncake . NoFData ()), Mooncake . CoDual (Ac, dA_inplace), Mooncake . CoDual (args, dargs_inplace))
66+ end
6767 else
68- rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
69- inplace_rrule = Mooncake. build_rrule (rvs_interp, inplace_sig)
70- inplace_out, inplace_pb!! = isnothing (alg) ? inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_inplace), Mooncake. CoDual (args, dargs_inplace)) : inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
68+ copy_out, copy_pb!! = rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_copy), Mooncake. CoDual (args, dargs_copy), Mooncake. CoDual (alg, Mooncake. NoFData ()))
69+ inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args), typeof (alg)}
70+ if hasmethod (Mooncake. rrule!!, inplace_sig) # hand written rule
71+ inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
72+ else
73+ inplace_rrule = Mooncake. build_rrule (rvs_interp, inplace_sig)
74+ inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
75+ end
7176 end
77+
78+ copy_pb!! (rdata)
7279 inplace_pb!! (rdata)
7380
7481 dA_inplace_ = Mooncake. arrayify (A, dA_inplace)[2 ]
0 commit comments