Skip to content

Commit 580555b

Browse files
committed
Nicer test_pullbacks_match
1 parent 0717789 commit 580555b

File tree

1 file changed

+56
-18
lines changed

1 file changed

+56
-18
lines changed

test/mooncake.jl

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,38 +41,76 @@ The arguments to this function are:
4141
- `alg` optional algorithm keyword argument
4242
- `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do)
4343
"""
44-
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData())
45-
f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg)
46-
sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)}
47-
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
48-
rrule = Mooncake.build_rrule(rvs_interp, sig)
49-
Ac = copy(A)
50-
ΔA = randn(rng, eltype(A), size(A))
51-
dA_copy = make_mooncake_tangent(copy(ΔA))
52-
dA_inplace = make_mooncake_tangent(copy(ΔA))
53-
44+
# no `alg` argument
45+
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
46+
dA_copy = make_mooncake_tangent(copy(ΔA))
47+
A_copy = copy(A)
5448
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
49+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy))
50+
copy_pb!!(rdata)
51+
return dA_copy
52+
end
5553

56-
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
57-
58-
copy_out, copy_pb!! = isnothing(alg) ? rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_copy), Mooncake.CoDual(args, dargs_copy)) : rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
54+
# `alg` argument
55+
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
56+
dA_copy = make_mooncake_tangent(copy(ΔA))
57+
A_copy = copy(A)
58+
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
59+
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
5960
copy_pb!!(rdata)
61+
return dA_copy
62+
end
6063

64+
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
65+
dA_inplace = make_mooncake_tangent(copy(ΔA))
66+
A_inplace = copy(A)
67+
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
6168
# not every f! has a handwritten rrule!!
62-
inplace_sig = isnothing(alg) ? Tuple{typeof(f!), typeof(A), typeof(args)} : Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
69+
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)}
70+
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
71+
if has_handwritten_rule
72+
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
73+
else
74+
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)}
75+
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
76+
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
77+
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace))
78+
end
79+
inplace_pb!!(rdata)
80+
return dA_inplace
81+
end
6382

83+
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
84+
dA_inplace = make_mooncake_tangent(copy(ΔA))
85+
A_inplace = copy(A)
86+
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
87+
# not every f! has a handwritten rrule!!
88+
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
6489
has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig)
6590
if has_handwritten_rule
66-
inplace_out, inplace_pb!! = isnothing(alg) ? Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) : Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
91+
inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
6792
else
68-
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
93+
inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)}
94+
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
6995
inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig)
70-
inplace_out, inplace_pb!! = isnothing(alg) ? inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) : inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(Ac, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
96+
inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData()))
7197
end
7298
inplace_pb!!(rdata)
99+
return dA_inplace
100+
end
101+
102+
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData())
103+
f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg)
104+
sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)}
105+
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
106+
rrule = Mooncake.build_rrule(rvs_interp, sig)
107+
ΔA = randn(rng, eltype(A), size(A))
108+
109+
dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
110+
dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
73111

74112
dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2]
75-
dA_copy_ = Mooncake.arrayify(A, dA_copy)[2]
113+
dA_copy_ = Mooncake.arrayify(A, dA_copy)[2]
76114
@test dA_inplace_ dA_copy_
77115
return
78116
end

0 commit comments

Comments
 (0)