Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions model/orbax/experimental/model/voxel2obm/main_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
from collections.abc import Mapping
import os
import re
from typing import Callable
from typing import Callable, Any

from orbax.experimental.model import core as obm
from orbax.experimental.model.core.python import file_utils
from orbax.experimental.model.voxel2obm import voxel_asset_map_pb2

from .learning.brain.experimental import jax_data as jd


VOXEL_PROCESSOR_MIME_TYPE = 'application/protobuf; type=voxel.PlanProto'
VOXEL_PROCESSOR_VERSION = '0.0.1'
Expand All @@ -37,7 +35,7 @@


def voxel_plan_to_obm(
voxel_module: jd.AbstractVoxelModule,
voxel_module: Any,
input_signature: obm.Tree[obm.ShloTensorSpec],
output_signature: obm.Tree[obm.ShloTensorSpec],
subfolder: str = DEFAULT_VOXEL_MODULE_FOLDER,
Expand Down Expand Up @@ -211,7 +209,7 @@ def _asset_map_to_obm_supplemental(


def voxel_global_supplemental_closure(
voxel_module: jd.AbstractVoxelModule,
voxel_module: Any,
) -> Callable[[str], Mapping[str, obm.GlobalSupplemental]] | None:
"""Returns a closure for saving Voxel assets and creating supplemental data.

Expand Down
23 changes: 4 additions & 19 deletions model/orbax/experimental/model/voxel2obm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"""Utilities for converting Voxel signatures to OBM."""

import pprint
from typing import Any
import jax
import numpy as np
from orbax.experimental.model import core as obm
from .learning.brain.experimental import jax_data as jd

VoxelSignatureTree = dict[str, Any]


def _obm_to_voxel_dtype(t):
Expand All @@ -27,31 +29,14 @@ def _obm_to_voxel_dtype(t):
return t


def obm_spec_to_voxel_signature(
spec: obm.Tree[obm.ShloTensorSpec],
) -> jd.VoxelSchemaTree:
try:
return jax.tree_util.tree_map(
lambda x: jd.VoxelTensorSpec(
shape=x.shape, dtype=obm.shlo_dtype_to_np_dtype(x.dtype)
),
spec,
)
except Exception as err:
raise ValueError(
'Failed to convert OBM spec of type'
f' {type(spec)} to Voxel:\n{pprint.pformat(spec)}'
) from err


def _voxel_to_obm_dtype(t) -> obm.ShloDType:
if not isinstance(t, np.dtype):
raise ValueError(f'Expected a numpy.dtype, got {t!r} of type {type(t)}')
return obm.np_dtype_to_shlo_dtype(t)


def voxel_signature_to_obm_spec(
signature: jd.VoxelSchemaTree,
signature: VoxelSignatureTree,
) -> obm.Tree[obm.ShloTensorSpec]:
try:
return jax.tree_util.tree_map(
Expand Down
33 changes: 9 additions & 24 deletions model/orbax/experimental/model/voxel2obm/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,27 @@

"""Tests for utils."""

import dataclasses

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from orbax.experimental.model import core as obm
from orbax.experimental.model.voxel2obm import utils
from .learning.brain.experimental import jax_data as jd


class UtilsTest(parameterized.TestCase):
@dataclasses.dataclass(frozen=True)
class MockVoxelSignature:
dtype: np.dtype
shape: tuple[int, ...]

def test_obm_spec_to_voxel_signature(self):
obm_spec = {
'a': obm.ShloTensorSpec(shape=(1, 2), dtype=obm.ShloDType.i32),
'b': obm.ShloTensorSpec(shape=(3,), dtype=obm.ShloDType.f32),
}
voxel_sig = utils.obm_spec_to_voxel_signature(obm_spec)
expected_voxel_sig = {
'a': jd.VoxelTensorSpec(
shape=(1, 2), dtype=np.dtype(np.int32)
),
'b': jd.VoxelTensorSpec(
shape=(3,), dtype=np.dtype(np.float32)
),
}

self.assertEqual(voxel_sig['a'], expected_voxel_sig['a'])
self.assertEqual(voxel_sig['b'], expected_voxel_sig['b'])
class UtilsTest(parameterized.TestCase):

def test_voxel_signature_to_obm_spec(self):
voxel_sig = {
'a': jd.VoxelTensorSpec(
shape=(1, 2), dtype=np.dtype(np.int32)
),
'b': jd.VoxelTensorSpec(
shape=(3,), dtype=np.dtype(np.float32)
),
'a': MockVoxelSignature(shape=(1, 2), dtype=np.dtype(np.int32)),
'b': MockVoxelSignature(shape=(3,), dtype=np.dtype(np.float32)),
}
obm_spec = utils.voxel_signature_to_obm_spec(voxel_sig)
expected_obm_spec = {
Expand Down
Loading