diff --git a/Project.toml b/Project.toml index 114038e21..6676b82b2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.25.1" +version = "1.25.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/projection.jl b/src/projection.jl index e4ed4d8dc..ba1696fd5 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -128,6 +128,8 @@ ProjectTo(::Any) = identity ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector, (::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients, (::ProjectTo{NoTangent})(dx::AbstractZero) = dx # and this one solves an ambiguity. +(::ProjectTo{NoTangent})(::InplaceableThunk) = NoTangent() # solves ambiguity, #685 +(::ProjectTo{NoTangent})(::Thunk) = NoTangent() # solves ambiguity, #685 # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). @@ -277,7 +279,7 @@ end # but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105 function ProjectTo(x::Ref) - sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)? + sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)? return ProjectTo{Tangent{typeof(x)}}(; x=sub) end (project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx)))) diff --git a/test/projection.jl b/test/projection.jl index 8cae1802a..052bd67f5 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -80,7 +80,7 @@ struct NoSuperType end prow = ProjectTo([1im 2 3im]) @test prow(transpose([1, 2, 3 + 4.0im])) == [1 2 3 + 4im] - @test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through + @test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through @test prow(adjoint([1, 2, 3 + 5im])) == [1 2 3 - 5im] @test prow(adjoint([1, 2, 3])) isa Matrix @@ -145,7 +145,7 @@ struct NoSuperType end @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} - + @test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero end @@ -376,7 +376,7 @@ struct NoSuperType end pvec3 = ProjectTo([1, 2, 3]) @test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,) - @test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test + @test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test @test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector end @@ -463,4 +463,12 @@ struct NoSuperType end psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end + + @testset "#685" begin + @test ProjectTo(BitArray([0]))([1.0]) == NoTangent() + @test ProjectTo(BitArray([0]))(@thunk [1.0]) == NoTangent() + + it = InplaceableThunk(x -> x + [1], @thunk [1.0]) + @test ProjectTo(BitArray([0]))(it) == NoTangent() + end end