From 77ceb5140755507c238417013ef1fea753cc3300 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 10:45:10 -0500 Subject: [PATCH 1/3] Use `_return_type` from Base, not Core.Compiler These used to be the same, but no longer are --- src/rulesets/Base/broadcast.jl | 4 ++-- src/rulesets/Base/mapreduce.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 4fb83c4e7..837863ebf 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -48,7 +48,7 @@ end # Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast. function may_bc_derivatives(::Type{T}, f::F, args::Vararg{Any,N}) where {T,F,N} - TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) + TΔ = Base._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) return isconcretetype(TΔ) end @@ -98,7 +98,7 @@ function may_bc_forwards(cfg::C, f::F, arg) where {C,F} TA = _eltype(arg) TA <: Real || return false cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad - TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA}) + TF = Base._return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA}) return isconcretetype(TF) && TF <: Tuple end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index f56ffa607..2c59a5a12 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -139,7 +139,7 @@ Works by seeing if the result of `derivatives_given_output(nothing, f, x)` can b The method of `derivatives_given_output` usually comes from `@scalar_rule`. """ function _uses_input_only(f::F, ::Type{xT}) where {F,xT} - gT = Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, F, xT}) + gT = Base._return_type(derivatives_given_output, Tuple{Nothing, F, xT}) # Here we must check `<: Number`, to avoid this, the one rule which can return the `nothing`: # ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),) return isconcretetype(gT) && gT <: Tuple{Tuple{Number}} From c0108b2709b2fe77f098ce7cf138cb043a9c7c57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Mon, 24 Feb 2025 07:53:46 -0500 Subject: [PATCH 2/3] Use Core.Compiler.return_type --- src/rulesets/Base/broadcast.jl | 4 ++-- src/rulesets/Base/mapreduce.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 837863ebf..d1610ce24 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -48,7 +48,7 @@ end # Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast. function may_bc_derivatives(::Type{T}, f::F, args::Vararg{Any,N}) where {T,F,N} - TΔ = Base._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) + TΔ = Core.Compiler.return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) return isconcretetype(TΔ) end @@ -98,7 +98,7 @@ function may_bc_forwards(cfg::C, f::F, arg) where {C,F} TA = _eltype(arg) TA <: Real || return false cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad - TF = Base._return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA}) + TF = Core.Compiler.return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA}) return isconcretetype(TF) && TF <: Tuple end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 2c59a5a12..999cfbb73 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -139,7 +139,7 @@ Works by seeing if the result of `derivatives_given_output(nothing, f, x)` can b The method of `derivatives_given_output` usually comes from `@scalar_rule`. """ function _uses_input_only(f::F, ::Type{xT}) where {F,xT} - gT = Base._return_type(derivatives_given_output, Tuple{Nothing, F, xT}) + gT = Core.Compiler.return_type(derivatives_given_output, Tuple{Nothing, F, xT}) # Here we must check `<: Number`, to avoid this, the one rule which can return the `nothing`: # ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),) return isconcretetype(gT) && gT <: Tuple{Tuple{Number}} From 332e91db3eb92fcacec615a959a761014a8fef42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Tue, 25 Feb 2025 12:47:32 -0500 Subject: [PATCH 3/3] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f1ac45605..9d47625a1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.72.2" +version = "1.72.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"