Skip to content

Fix nnx tabulate variable hooks#5008

Merged
copybara-service[bot] merged 4 commits intogoogle:mainfrom
mohsinm-dev:fix-nnx-tabulate-variable-hooks
Oct 15, 2025
Merged

Fix nnx tabulate variable hooks#5008
copybara-service[bot] merged 4 commits intogoogle:mainfrom
mohsinm-dev:fix-nnx-tabulate-variable-hooks

Conversation

@mohsinm-dev
Copy link
Contributor

What does this PR do?

Fixes #5005

Summary

  • Prevents nnx.tabulate from crashing with yaml.representer.RepresenterError when a nnx.Variable stores value hooks (e.g.,
    on_set_value) or other non-serializable objects in metadata.
  • Filters out non-serializable callables and modules during YAML serialization in flax/nnx/summary.py::_as_yaml_str.
  • Adds a deterministic fallback to repr(obj) for unknown non-serializable metadata, ensuring tabulate never fails on arbitrary
    custom objects.
  • Adds a focused test that validates the repr fallback appears in the tabulate output.

Checklist

Handle callables in _normalize_values to fix YAML serialization errors
when Variables have hook methods like on_set_value.
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update @mohsinm-dev !
Few other thoughts on the changes.

Also please check the failing test: https://github.com/google/flax/actions/runs/18472444634/job/52629427889?pr=5008

)
return file.getvalue().replace('\n...', '').replace("'", '').strip()
except yaml.representer.RepresenterError:
# Fallback for non-serializable objects not caught by _normalize_values
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I wonder why we can't just do this fallback in _normalize_values:

def _normalize_values(x):
  if isinstance(x, type):
    return f'type[{x.__name__}]'
  elif isinstance(x, ArrayRepr | SimpleObjectRepr):
    return str(x)
  else:
    return repr(x)  #  <--- here 

Less code to maintain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, _normalize_values is already called via jax.tree.map() on all values before YAML serialization, so we can use repr(x) as the default fallback.

self.assertIsNotNone(table_repr)
# Ensure the table contains expected content
self.assertIn('Model Summary', table_repr)
self.assertIn('param', table_repr)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.assertIn('param', table_repr)
self.assertIn('param', table_repr)
self.assertIn('on_set_value', table_repr)

self.assertIn('Model Summary', table_repr)
self.assertIn('param', table_repr)

def test_tabulate_with_multiple_hooks_and_metadata(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge these test cases: test_tabulate_with_multiple_hooks_and_metadata and test_tabulate_with_custom_nonserializable_metadata with the above one: test_tabulate_with_variable_hooks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we can just merge it into the one that makes more sense.

Use repr(x) as default fallback in _normalize_values to handle all
non-serializable objects uniformly. Merge related test cases into
comprehensive test_tabulate_with_variable_hooks.
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @mohsinm-dev !

@mohsinm-dev
Copy link
Contributor Author

LGTM, thanks @mohsinm-dev !

Thanks for the guidance and feedback!

Copy link
Collaborator

@cgarciae cgarciae left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mohsinm-dev and @vfdev-5 !

@copybara-service copybara-service bot merged commit befe571 into google:main Oct 15, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

nnx.tabulate breaks if nnx.Variable implements one of the value hooks

3 participants