Skip to content

Commit c774e92

Browse files
committed
test gpu
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 9a28579 commit c774e92

2 files changed

Lines changed: 21 additions & 16 deletions

File tree

monai/metrics/utils.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MetricReduction,
2525
convert_data_type,
2626
convert_to_tensor,
27+
convert_to_numpy,
2728
ensure_tuple_rep,
2829
look_up_option,
2930
optional_import,
@@ -165,38 +166,41 @@ def get_mask_edges(
165166
seg_pred = seg_pred == label_idx
166167
if seg_gt.dtype not in (bool, torch.bool):
167168
seg_gt = seg_gt == label_idx
168-
169169
if crop:
170-
if not (seg_pred | seg_gt).any():
171-
pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt)
170+
or_vol = seg_pred | seg_gt
171+
if not or_vol.any():
172+
pred, gt = np.zeros(seg_pred.shape, dtype=bool), np.zeros(seg_gt.shape, dtype=bool)
172173
return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore
174+
channel_first = [seg_pred[None], seg_gt[None], or_vol[None]]
175+
if spacing is None: # cpu only erosion
176+
seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool)
177+
else: # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported
178+
seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16)
173179
cropper = CropForegroundD(
174180
["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None
175181
)
176-
mask = seg_pred | seg_gt
177-
cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore
178-
seg_pred = cropped["pred"][0]
179-
seg_gt = cropped["gt"][0]
180-
181-
if spacing is None:
182-
# Do binary erosion and use XOR to get edges
183-
seg_pred = convert_data_type(seg_pred, np.ndarray)[0]
184-
seg_gt = convert_data_type(seg_gt, np.ndarray)[0]
182+
cropped = cropper({"pred": seg_pred, "gt": seg_gt, "src": or_vol}) # type: ignore
183+
seg_pred, seg_gt = cropped["pred"][0], cropped["gt"][0]
184+
185+
if spacing is None: # Do binary erosion and use XOR to get edges
186+
seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool)
185187
edges_pred = binary_erosion(seg_pred) ^ seg_pred
186188
edges_gt = binary_erosion(seg_gt) ^ seg_gt
187189
return edges_pred, edges_gt
188-
code_to_area_table, k = get_code_to_measure_table(spacing)
190+
code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device)
189191
spatial_dims = len(spacing)
190192
conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d
191-
code_pred, code_gt = conv(torch.stack([seg_pred[None], seg_gt[None]], dim=0).float(), k.float()) # type: ignore
193+
vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float()
194+
code_pred, code_gt = conv(vol, k.to(vol)) # type: ignore
192195
# edges
193196
all_ones = len(code_to_area_table) - 1
194197
edges_pred = (code_pred != 0) & (code_pred != all_ones)
195198
edges_gt = (code_gt != 0) & (code_gt != all_ones)
196199
# areas of edges
197200
areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape)
198201
areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape)
199-
return edges_pred.array[0], edges_gt.array[0], areas_pred.array[0], areas_gt.array[0] # type: ignore
202+
ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0])
203+
return convert_to_numpy(ret, wrap_sequence=False)
200204

201205

202206
def get_surface_distance(

tests/test_surface_dice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,8 @@ def test_compute_surface_dice_subvoxel(self):
402402
)
403403
assert_allclose(res, 0.5, type_test=False)
404404

405-
mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100), torch.zeros(1, 1, 100, 100, 100)
405+
d = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
406+
mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100, device=d), torch.zeros(1, 1, 100, 100, 100, device=d)
406407
mask_gt[0, 0, 0:50, :, :] = 1
407408
mask_pred[0, 0, 0:51, :, :] = 1
408409
res = compute_surface_dice(

0 commit comments

Comments
 (0)