Skip to content

Commit 2c5c89f

Browse files
authored
Add warning in RandHistogramShift (#5877)
Signed-off-by: KumoLiu <yunl@nvidia.com> Fixes #5875 . ### Description Add warning in `RandHistogramShift` when the image's intensity is a single value. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent bdf5e1e commit 2c5c89f

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

monai/transforms/intensity/array.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1468,9 +1468,15 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
14681468
if self.reference_control_points is None or self.floating_control_points is None:
14691469
raise RuntimeError("please call the `randomize()` function first.")
14701470
img_t = convert_to_tensor(img, track_meta=False)
1471+
img_min, img_max = img_t.min(), img_t.max()
1472+
if img_min == img_max:
1473+
warn(
1474+
f"The image's intensity is a single value {img_min}. "
1475+
"The original image is simply returned, no histogram shift is done."
1476+
)
1477+
return img
14711478
xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img_t)
14721479
yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img_t)
1473-
img_min, img_max = img_t.min(), img_t.max()
14741480
reference_control_points_scaled = xp * (img_max - img_min) + img_min
14751481
floating_control_points_scaled = yp * (img_max - img_min) + img_min
14761482
img_t = self.interp(img_t, reference_control_points_scaled, floating_control_points_scaled)

tests/test_rand_histogram_shift.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@
4444
]
4545
)
4646

47+
WARN_TESTS = []
48+
for p in TEST_NDARRAYS:
49+
WARN_TESTS.append(
50+
[
51+
{"num_control_points": 5, "prob": 1.0},
52+
{"img": p(np.zeros(8).reshape((1, 2, 2, 2)))},
53+
np.zeros(8).reshape((1, 2, 2, 2)),
54+
]
55+
)
56+
4757

4858
class TestRandHistogramShift(unittest.TestCase):
4959
@parameterized.expand(TESTS)
@@ -71,6 +81,12 @@ def test_interp(self):
7181
self.assertEqual(yi.shape, (3, 2))
7282
assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]]))
7383

84+
@parameterized.expand(WARN_TESTS)
85+
def test_warn(self, input_param, input_data, expected_val):
86+
with self.assertWarns(Warning):
87+
result = RandHistogramShift(**input_param)(**input_data)
88+
assert_allclose(result, expected_val, type_test="tensor")
89+
7490

7591
if __name__ == "__main__":
7692
unittest.main()

0 commit comments

Comments
 (0)