Support functions that splat namedtuples as keyword arguments#1059
Support functions that splat namedtuples as keyword arguments#1059
Conversation
| g(; kwargs...) = kwargs[:x] * kwargs[:z] | ||
| h(somedata) = g(; somedata...) | ||
| @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),) | ||
| @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),) |
There was a problem hiding this comment.
I am not sure what type Zygote wants to use to represent Dict.
accum isn't defined for Dict AFAICT.
So NamedTuple seemed reasonable.
There was a problem hiding this comment.
the Dict story is not consistent at the moment and there are many missing features
https://github.com/FluxML/Zygote.jl/issues?q=is%3Aissue+is%3Aopen+label%3Adictionary
I guess generally the returned gradient should be a dictionary, but for dicts with symbol keys maybe a namedtuple is good enough for the time being
DhairyaLGandhi
left a comment
There was a problem hiding this comment.
Thanks for the follow up!
| if VERSION >= v"1.6" | ||
| @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict)) | ||
| else | ||
| @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...)) |
There was a problem hiding this comment.
What do you mean?
You mean use if for both 1.6 and pre-1.6?
|
The CUDA failures seem related to CUDA.jl v3.4 -- We should get some kind of fix in Flux/ Zygote or CUDA since |
|
Bump |
|
CI failures are unrelated. |
@lsindoni ran into this.
This problem occurs for Dict and NamedTuples with a error about the methods of
Base.setindex.The
Dictin thepairspullback in this case ends up containingSymbols notIntegers.So we branch to support that.
(unfortunately it is a
Dict{Any,Any}so we can't branch earlier, and so it is dynamic dispatch).The chain to the rule for
pairsis enough to make NamedTuple's work,If we also want to support Dicts, we need the rule for
mergeas wellI am surprised noone has run into this before.