Skip to content

Support functions that splat namedtuples as keyword arguments#1059

Merged
oxinabox merged 3 commits intomasterfrom
ox/kwsplat
Sep 10, 2021
Merged

Support functions that splat namedtuples as keyword arguments#1059
oxinabox merged 3 commits intomasterfrom
ox/kwsplat

Conversation

@oxinabox
Copy link
Copy Markdown
Member

@oxinabox oxinabox commented Sep 2, 2021

@lsindoni ran into this.
This problem occurs for Dict and NamedTuples with a error about the methods of Base.setindex.

The Dict in the pairs pullback in this case ends up containing Symbols not Integers.
So we branch to support that.
(unfortunately it is a Dict{Any,Any} so we can't branch earlier, and so it is dynamic dispatch).

The chain to the rule for pairs is enough to make NamedTuple's work,
If we also want to support Dicts, we need the rule for merge as well

I am surprised noone has run into this before.

Comment thread test/features.jl
g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what type Zygote wants to use to represent Dict.
accum isn't defined for Dict AFAICT.
So NamedTuple seemed reasonable.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the Dict story is not consistent at the moment and there are many missing features
https://github.com/FluxML/Zygote.jl/issues?q=is%3Aissue+is%3Aopen+label%3Adictionary
I guess generally the returned gradient should be a dictionary, but for dicts with symbol keys maybe a namedtuple is good enough for the time being

Copy link
Copy Markdown
Member

@DhairyaLGandhi DhairyaLGandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the follow up!

Comment thread src/lib/base.jl
if VERSION >= v"1.6"
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict))
else
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just keep this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean?
You mean use if for both 1.6 and pre-1.6?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, exactly.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

@DhairyaLGandhi
Copy link
Copy Markdown
Member

The CUDA failures seem related to CUDA.jl v3.4 -- We should get some kind of fix in Flux/ Zygote or CUDA since cu seems to have changed behaviour in a couple places. @vchuravy would you know what needs to be done here?

@oxinabox
Copy link
Copy Markdown
Member Author

Bump

@oxinabox
Copy link
Copy Markdown
Member Author

CI failures are unrelated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants