diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 7e1befd14..1334cc925 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -1,14 +1,6 @@ -##### -##### getindex(::Tuple) -##### - -function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer) - return x[i], ẋ[i] -end - -function frule((_, ẋ), ::typeof(getindex), x::Tuple, i) - y = x[i] - return y, Tangent{typeof(y)}(ẋ[i]...) +# Int rather than Int64/Integer is intentional +function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) + return x.i, ẋ.i end "for a given tuple type, returns a Val{N} where N is the length of the tuple" @@ -77,7 +69,7 @@ end """ ∇getindex(x, dy, inds...) -For the `rrule` of `y = x[inds...]`, this function is roughly +For the `rrule` of `y = x[inds...]`, this function is roughly `setindex(zero(x), dy, inds...)`, returning the array `dx`. Differentiable. Includes `ProjectTo(x)(dx)`. """ @@ -191,29 +183,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds...) return dx end -##### -##### first, tail -##### - -function frule((_, ẋ), ::typeof(first), x::Tuple) - return first(x), first(ẋ) -end - -function rrule(::typeof(first), x::T) where {T<:Tuple} - first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...)) - return first(x), first_back -end - -function frule((_, ẋ), ::typeof(Base.tail), x::Tuple) - y = Base.tail(x) - return y, Tangent{typeof(y)}(Base.tail(ẋ)...) -end - -function rrule(::typeof(Base.tail), x::T) where {T<:Tuple} - tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...)) - return Base.tail(x), tail_pullback -end - ##### ##### view ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 8928c55e7..d3c7ecfb4 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -3,12 +3,7 @@ x = (1.2, 3.4, 5.6) x2 = (rand(2), (a=1.0, b=x)) - # Forward - test_frule(getindex, x, 2) - test_frule(getindex, x2, 1) - test_frule(getindex, x, 1:2) - test_frule(getindex, x2, :) - + # don't test Forward because this will be handled by lowering to getfield # Reverse test_rrule(getindex, x, 2) @test_skip test_rrule(getindex, x2, 1, check_inferred=false) # method ambiguity, maybe fixed by https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/253 @@ -168,22 +163,7 @@ end end -@testset "first & tail" begin - x = (1.2, 3.4, 5.6) - x2 = (rand(2), (a=1.0, b=x)) - - test_frule(first, x) - test_frule(first, x2) - - test_rrule(first, x) - # test_rrule(first, x2) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::NoTangent, ::Tangent{NamedTuple{(:a, :b), Tuple{Float64, Tuple{Float64, Float64, Float64}}}, NamedTuple{(:a, :b), Tuple{Float64, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}}, ::String) is ambiguous - - test_frule(Base.tail, x, check_inferred=false) # return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}} does not match inferred return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}}} - test_frule(Base.tail, x2, check_inferred=false) - - test_rrule(Base.tail, x) - test_rrule(Base.tail, x2) -end +# first & tail handled by getfield rules @testset "view" begin test_frule(view, rand(3, 4), :, 1)