Skip to content

Fix nnx.tabulate crash with empty dict/None values (fixes #4889)#4891

Merged
copybara-service[bot] merged 4 commits intogoogle:mainfrom
mohsinm-dev:fix-nnx-tabulate-empty-dict-issue-4889
Sep 29, 2025
Merged

Fix nnx.tabulate crash with empty dict/None values (fixes #4889)#4891
copybara-service[bot] merged 4 commits intogoogle:mainfrom
mohsinm-dev:fix-nnx-tabulate-empty-dict-issue-4889

Conversation

@mohsinm-dev
Copy link
Contributor

Fixes #4889 by handling gaps in sequence indices when JAX tree flattening omits empty containers.

What does this PR do?

This PR fixes a crash in nnx.tabulate when called with depth >= 1 on modules that contain empty dictionaries {} or None values.

Problem

When nnx.tabulate processes function arguments that include empty containers (like {}), JAX's tree flattening skips these empty containers completely. This creates gaps in the sequence indices, causing an assertion error in the _unflatten_to_simple_structure function.

For example, with input ({}, array), JAX produces [((1,), array)] instead of [((0,), {}), ((1,), array)]. The code expected index 0 but got index 1, causing assert 1 == 0 to fail.

Solution

Replace the problematic assertion with logic that handles index gaps by filling missing positions with None values:

# Before (crashed):
assert path[-1] == len(cursor)
cursor.append(value)

# After (works):
while len(cursor) <= path[-1]:
    cursor.append(None) 
cursor[path[-1]] = value

Test Case

class Model(nnx.Module):
    def __init__(self):
        self.foo = {}  # This used to crash tabulate
    def __call__(self, x):
        return x

# This now works (was failing before):
nnx.tabulate(Model(), jnp.zeros((1, 10)), depth=1)

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

Fixes google#4889 by handling gaps in sequence indices when JAX tree
flattening omits empty containers.
@samanklesaria
Copy link
Collaborator

samanklesaria commented Sep 15, 2025

Thanks for the fix! A couple comments:

  1. The test case you're using above doesn't actually test for the error observed in Edge case in nnx.tabulate when nnx.Module stores empty dictionary #4889 because self.foo is never referred to by the __call__ method (which is all the tabulate function is looking at). The _unflatten_to_simple_structure function never gets called. As written, tabulate works fine on your test case even without your fix. Instead, I'd recommend a test case like the following:
class Model(nnx.Module):
    def subroutine(self, foo, x):
        return x

    def __call__(self, x):
        return self.subroutine({}, x)
        
model = Model()
nnx.tabulate(model, jnp.zeros((1, 8)), depth=1)

This breaks using the current version of tabulate, but works fine with your fix, giving inputs null and float32[1,8] for the call to subroutine.

  1. Your fix only applies when empty dictionary arguments are given before a non-empty argument. If we switch the test case above to the following, the empty input to subroutine is ignored by the call to tabulate.
class Model(nnx.Module):
   def subroutine(self, foo, x):
       return x

   def __call__(self, x):
       return self.subroutine(x, {})
       
model = Model()
nnx.tabulate(model, jnp.zeros((1, 8)), depth=1)

This dropping of the empty dictionary argument also occurs when it is passed as a keyword argument instead of a positional one.

  1. Could you add test cases like these to the tests/nnx/summary_test.py file?

  2. Both None and empty dictionaries get represented as null in the output table. Ideally, we'd see None for None arguments and {} for empty dictionary arguments. Would this kind of behavior be possible?

Copy link
Collaborator

@samanklesaria samanklesaria left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! Besides the nitpicks above, everything looks good to me.

@copybara-service copybara-service bot merged commit f65afdb into google:main Sep 29, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Edge case in nnx.tabulate when nnx.Module stores empty dictionary

3 participants