Fix nnx tabulate variable hooks#5008
Conversation
Handle callables in _normalize_values to fix YAML serialization errors when Variables have hook methods like on_set_value.
There was a problem hiding this comment.
Thanks for the update @mohsinm-dev !
Few other thoughts on the changes.
Also please check the failing test: https://github.com/google/flax/actions/runs/18472444634/job/52629427889?pr=5008
flax/nnx/summary.py
Outdated
| ) | ||
| return file.getvalue().replace('\n...', '').replace("'", '').strip() | ||
| except yaml.representer.RepresenterError: | ||
| # Fallback for non-serializable objects not caught by _normalize_values |
There was a problem hiding this comment.
Actually, I wonder why we can't just do this fallback in _normalize_values:
def _normalize_values(x):
if isinstance(x, type):
return f'type[{x.__name__}]'
elif isinstance(x, ArrayRepr | SimpleObjectRepr):
return str(x)
else:
return repr(x) # <--- here Less code to maintain.
There was a problem hiding this comment.
Yes, _normalize_values is already called via jax.tree.map() on all values before YAML serialization, so we can use repr(x) as the default fallback.
tests/nnx/summary_test.py
Outdated
| self.assertIsNotNone(table_repr) | ||
| # Ensure the table contains expected content | ||
| self.assertIn('Model Summary', table_repr) | ||
| self.assertIn('param', table_repr) |
There was a problem hiding this comment.
| self.assertIn('param', table_repr) | |
| self.assertIn('param', table_repr) | |
| self.assertIn('on_set_value', table_repr) |
tests/nnx/summary_test.py
Outdated
| self.assertIn('Model Summary', table_repr) | ||
| self.assertIn('param', table_repr) | ||
|
|
||
| def test_tabulate_with_multiple_hooks_and_metadata(self): |
There was a problem hiding this comment.
Can we merge these test cases: test_tabulate_with_multiple_hooks_and_metadata and test_tabulate_with_custom_nonserializable_metadata with the above one: test_tabulate_with_variable_hooks.
There was a problem hiding this comment.
Yes, I think we can just merge it into the one that makes more sense.
Use repr(x) as default fallback in _normalize_values to handle all non-serializable objects uniformly. Merge related test cases into comprehensive test_tabulate_with_variable_hooks.
vfdev-5
left a comment
There was a problem hiding this comment.
LGTM, thanks @mohsinm-dev !
Thanks for the guidance and feedback! |
What does this PR do?
Fixes #5005
Summary
on_set_value) or other non-serializable objects in metadata.
custom objects.
Checklist
nnx.Variableimplements one of the value hooks #5005