Fix nnx.tabulate crash with empty dict/None values (fixes #4889)#4891
Merged
copybara-service[bot] merged 4 commits intogoogle:mainfrom Sep 29, 2025
Merged
Conversation
Fixes google#4889 by handling gaps in sequence indices when JAX tree flattening omits empty containers.
Collaborator
|
Thanks for the fix! A couple comments:
class Model(nnx.Module):
def subroutine(self, foo, x):
return x
def __call__(self, x):
return self.subroutine({}, x)
model = Model()
nnx.tabulate(model, jnp.zeros((1, 8)), depth=1)This breaks using the current version of
class Model(nnx.Module):
def subroutine(self, foo, x):
return x
def __call__(self, x):
return self.subroutine(x, {})
model = Model()
nnx.tabulate(model, jnp.zeros((1, 8)), depth=1)This dropping of the empty dictionary argument also occurs when it is passed as a keyword argument instead of a positional one.
|
samanklesaria
approved these changes
Sep 29, 2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #4889 by handling gaps in sequence indices when JAX tree flattening omits empty containers.
What does this PR do?
This PR fixes a crash in
nnx.tabulatewhen called withdepth >= 1on modules that contain empty dictionaries{}orNonevalues.Problem
When
nnx.tabulateprocesses function arguments that include empty containers (like{}), JAX's tree flattening skips these empty containers completely. This creates gaps in the sequence indices, causing an assertion error in the_unflatten_to_simple_structurefunction.For example, with input
({}, array), JAX produces[((1,), array)]instead of[((0,), {}), ((1,), array)]. The code expected index 0 but got index 1, causingassert 1 == 0to fail.Solution
Replace the problematic assertion with logic that handles index gaps by filling missing positions with
Nonevalues:Test Case
Checklist