Skip to content

Commit 21b9076

Browse files
kshyattJutho
andauthored
Update test/mooncake.jl
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 615dfeb commit 21b9076

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

test/mooncake.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,27 @@ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata
5555

5656
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
5757

58-
copy_out, copy_pb!! = isnothing(alg) ? rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_copy), Mooncake.CoDual(args, dargs_copy)) : rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
59-
copy_pb!!(rdata)
60-
61-
# not every f! has a handwritten rrule!!
62-
inplace_sig = isnothing(alg) ? Tuple{typeof(f!), typeof(A), typeof(args)} : Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
63-
64-
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
65-
if has_handwritten_rule
66-
inplace_out, inplace_pb!! = isnothing(alg) ? Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) : Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
58+
if isnothing(alg)
59+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_copy), Mooncake.CoDual(args, dargs_copy))
60+
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)})
61+
if hasmethod(Mooncake.rrule!!, inplace_sig) # hand written rule
62+
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
63+
else
64+
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
65+
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
66+
end
6767
else
68-
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
69-
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
70-
inplace_out, inplace_pb!! = isnothing(alg) ? inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) : inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
68+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
69+
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
70+
if hasmethod(Mooncake.rrule!!, inplace_sig) # hand written rule
71+
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
72+
else
73+
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
74+
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
75+
end
7176
end
77+
78+
copy_pb!!(rdata)
7279
inplace_pb!!(rdata)
7380

7481
dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2]

0 commit comments

Comments
 (0)