@@ -41,38 +41,76 @@ The arguments to this function are:
4141 - `alg` optional algorithm keyword argument
4242 - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do)
4343"""
44- function test_pullbacks_match (rng, f!, f, A, args, Δargs, alg = nothing ; rdata = Mooncake. NoRData ())
45- f_c = isnothing (alg) ? (A, args) -> f! (MatrixAlgebraKit. copy_input (f, A), args) : (A, args, alg) -> f! (MatrixAlgebraKit. copy_input (f, A), args, alg)
46- sig = isnothing (alg) ? Tuple{typeof (f_c), typeof (A), typeof (args)} : Tuple{typeof (f_c), typeof (A), typeof (args), typeof (alg)}
47- rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
48- rrule = Mooncake. build_rrule (rvs_interp, sig)
49- Ac = copy (A)
50- ΔA = randn (rng, eltype (A), size (A))
51- dA_copy = make_mooncake_tangent (copy (ΔA))
52- dA_inplace = make_mooncake_tangent (copy (ΔA))
53-
44+ # no `alg` argument
45+ function _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, :: Nothing , rdata)
46+ dA_copy = make_mooncake_tangent (copy (ΔA))
47+ A_copy = copy (A)
5448 dargs_copy = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
49+ copy_out, copy_pb!! = rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (A_copy, dA_copy), Mooncake. CoDual (args, dargs_copy))
50+ copy_pb!! (rdata)
51+ return dA_copy
52+ end
5553
56- dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
57-
58- copy_out, copy_pb!! = isnothing (alg) ? rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_copy), Mooncake. CoDual (args, dargs_copy)) : rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_copy), Mooncake. CoDual (args, dargs_copy), Mooncake. CoDual (alg, Mooncake. NoFData ()))
54+ # `alg` argument
55+ function _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
56+ dA_copy = make_mooncake_tangent (copy (ΔA))
57+ A_copy = copy (A)
58+ dargs_copy = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
59+ copy_out, copy_pb!! = rrule (Mooncake. CoDual (f_c, Mooncake. NoFData ()), Mooncake. CoDual (A_copy, dA_copy), Mooncake. CoDual (args, dargs_copy), Mooncake. CoDual (alg, Mooncake. NoFData ()))
5960 copy_pb!! (rdata)
61+ return dA_copy
62+ end
6063
64+ function _get_inplace_derivative (f!, A, ΔA, args, Δargs, :: Nothing , rdata)
65+ dA_inplace = make_mooncake_tangent (copy (ΔA))
66+ A_inplace = copy (A)
67+ dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
6168 # not every f! has a handwritten rrule!!
62- inplace_sig = isnothing (alg) ? Tuple{typeof (f!), typeof (A), typeof (args)} : Tuple{typeof (f!), typeof (A), typeof (args), typeof (alg)}
69+ inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args)}
70+ has_handwritten_rule = hasmethod (Mooncake. rrule!!, inplace_sig)
71+ if has_handwritten_rule
72+ inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (A_inplace, dA_inplace), Mooncake. CoDual (args, dargs_inplace))
73+ else
74+ inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args)}
75+ rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
76+ inplace_rrule = Mooncake. build_rrule (rvs_interp, inplace_sig)
77+ inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (A_inplace, dA_inplace), Mooncake. CoDual (args, dargs_inplace))
78+ end
79+ inplace_pb!! (rdata)
80+ return dA_inplace
81+ end
6382
83+ function _get_inplace_derivative (f!, A, ΔA, args, Δargs, alg, rdata)
84+ dA_inplace = make_mooncake_tangent (copy (ΔA))
85+ A_inplace = copy (A)
86+ dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata .(deepcopy (Δargs)) : make_mooncake_fdata (deepcopy (Δargs))
87+ # not every f! has a handwritten rrule!!
88+ inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args), typeof (alg)}
6489 has_handwritten_rule = hasmethod (Mooncake. rrule!!, inplace_sig)
6590 if has_handwritten_rule
66- inplace_out, inplace_pb!! = isnothing (alg) ? Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_inplace), Mooncake . CoDual (args, dargs_inplace)) : Mooncake . rrule!! (Mooncake . CoDual (f!, Mooncake . NoFData ()), Mooncake . CoDual (Ac , dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
91+ inplace_out, inplace_pb!! = Mooncake. rrule!! (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (A_inplace , dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
6792 else
68- rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
93+ inplace_sig = Tuple{typeof (f!), typeof (A), typeof (args), typeof (alg)}
94+ rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
6995 inplace_rrule = Mooncake. build_rrule (rvs_interp, inplace_sig)
70- inplace_out, inplace_pb!! = isnothing (alg) ? inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (Ac, dA_inplace), Mooncake . CoDual (args, dargs_inplace)) : inplace_rrule (Mooncake . CoDual (f!, Mooncake . NoFData ()), Mooncake . CoDual (Ac , dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
96+ inplace_out, inplace_pb!! = inplace_rrule (Mooncake. CoDual (f!, Mooncake. NoFData ()), Mooncake. CoDual (A_inplace , dA_inplace), Mooncake. CoDual (args, dargs_inplace), Mooncake. CoDual (alg, Mooncake. NoFData ()))
7197 end
7298 inplace_pb!! (rdata)
99+ return dA_inplace
100+ end
101+
102+ function test_pullbacks_match (rng, f!, f, A, args, Δargs, alg = nothing ; rdata = Mooncake. NoRData ())
103+ f_c = isnothing (alg) ? (A, args) -> f! (MatrixAlgebraKit. copy_input (f, A), args) : (A, args, alg) -> f! (MatrixAlgebraKit. copy_input (f, A), args, alg)
104+ sig = isnothing (alg) ? Tuple{typeof (f_c), typeof (A), typeof (args)} : Tuple{typeof (f_c), typeof (A), typeof (args), typeof (alg)}
105+ rvs_interp = Mooncake. get_interpreter (Mooncake. ReverseMode)
106+ rrule = Mooncake. build_rrule (rvs_interp, sig)
107+ ΔA = randn (rng, eltype (A), size (A))
108+
109+ dA_copy = _get_copying_derivative (f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
110+ dA_inplace = _get_inplace_derivative (f!, A, ΔA, args, Δargs, alg, rdata)
73111
74112 dA_inplace_ = Mooncake. arrayify (A, dA_inplace)[2 ]
75- dA_copy_ = Mooncake. arrayify (A, dA_copy)[2 ]
113+ dA_copy_ = Mooncake. arrayify (A, dA_copy)[2 ]
76114 @test dA_inplace_ ≈ dA_copy_
77115 return
78116end
0 commit comments