Skip to content

Commit d89b457

Browse files
authored
Update GDSDataset (#6787)
Fixes #6786. ### Description - Update rst - Update the type of dtype to str ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 87d0ede commit d89b457

4 files changed

Lines changed: 52 additions & 14 deletions

File tree

.github/workflows/docker.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ jobs:
8989
steps:
9090
- name: Import
9191
run: |
92-
export CUDA_VISIBLE_DEVICES= # cpu-only
92+
export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES= # cpu-only
9393
python -c 'import monai; monai.config.print_debug_info()'
9494
cd /opt/monai
9595
ls -al

docs/source/data.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ Generic Interfaces
4545
:members:
4646
:special-members: __getitem__
4747

48+
`GDSDataset`
49+
~~~~~~~~~~~~~~~~~~~
50+
.. autoclass:: GDSDataset
51+
:members:
52+
:special-members: __getitem__
53+
54+
4855
`CacheNTransDataset`
4956
~~~~~~~~~~~~~~~~~~~~
5057
.. autoclass:: CacheNTransDataset

monai/data/dataset.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,8 @@ class GDSDataset(PersistentDataset):
15211521
bandwidth while decreasing latency and utilization load on the CPU and GPU.
15221522
15231523
A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.
1524+
1525+
See also: https://github.com/rapidsai/kvikio
15241526
"""
15251527

15261528
def __init__(
@@ -1607,17 +1609,20 @@ def _cachecheck(self, item_transformed):
16071609
return item
16081610
elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
16091611
_meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
1610-
_data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(()))
1611-
_data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}")
1612-
if bool(_meta):
1612+
_data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(()))
1613+
_data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}")
1614+
filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys()))
1615+
if bool(filtered_keys):
16131616
return (_data, _meta)
16141617
return _data
16151618
else:
16161619
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
16171620
for i, _item in enumerate(item_transformed):
16181621
for k in _item:
16191622
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
1620-
item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(()))
1623+
item_k = kvikio_numpy.fromfile(
1624+
f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(())
1625+
)
16211626
item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
16221627
item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
16231628
return item
@@ -1653,7 +1658,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
16531658
if isinstance(_item_transformed_data, torch.Tensor):
16541659
_item_transformed_data = _item_transformed_data.numpy()
16551660
self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
1656-
self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype
1661+
self._meta_cache[meta_hash_file_name]["dtype"] = str(_item_transformed_data.dtype)
16571662
kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
16581663
try:
16591664
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation

tests/test_gdsdataset.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import unittest
1818

1919
import numpy as np
20+
import torch
2021
from parameterized import parameterized
2122

2223
from monai.data import GDSDataset, json_hashing
@@ -48,6 +49,19 @@
4849

4950
TEST_CASE_3 = [None, (128, 128, 128)]
5051

52+
DTYPES = {
53+
np.dtype(np.uint8): torch.uint8,
54+
np.dtype(np.int8): torch.int8,
55+
np.dtype(np.int16): torch.int16,
56+
np.dtype(np.int32): torch.int32,
57+
np.dtype(np.int64): torch.int64,
58+
np.dtype(np.float16): torch.float16,
59+
np.dtype(np.float32): torch.float32,
60+
np.dtype(np.float64): torch.float64,
61+
np.dtype(np.complex64): torch.complex64,
62+
np.dtype(np.complex128): torch.complex128,
63+
}
64+
5165

5266
class _InplaceXform(Transform):
5367
def __call__(self, data):
@@ -93,16 +107,28 @@ def test_metatensor(self):
93107
shape = (1, 10, 9, 8)
94108
items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))]
95109
with tempfile.TemporaryDirectory() as tempdir:
96-
ds = GDSDataset(
97-
data=items,
98-
transform=_InplaceXform(),
99-
cache_dir=tempdir,
100-
device=0,
101-
pickle_module="pickle",
102-
pickle_protocol=pickle.HIGHEST_PROTOCOL,
103-
)
110+
ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
104111
assert_allclose(ds[0], ds[0][0], type_test=False)
105112

113+
def test_dtype(self):
114+
shape = (1, 10, 9, 8)
115+
data = np.arange(0, np.prod(shape)).reshape(shape)
116+
for _dtype in DTYPES.keys():
117+
items = [np.array(data).astype(_dtype)]
118+
with tempfile.TemporaryDirectory() as tempdir:
119+
ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
120+
ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
121+
self.assertEqual(ds[0].dtype, _dtype)
122+
self.assertEqual(ds1[0].dtype, DTYPES[_dtype])
123+
124+
for _dtype in DTYPES.keys():
125+
items = [torch.tensor(data, dtype=DTYPES[_dtype])]
126+
with tempfile.TemporaryDirectory() as tempdir:
127+
ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
128+
ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
129+
self.assertEqual(ds[0].dtype, DTYPES[_dtype])
130+
self.assertEqual(ds1[0].dtype, DTYPES[_dtype])
131+
106132
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
107133
def test_shape(self, transform, expected_shape):
108134
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))

0 commit comments

Comments (0)