|
24 | 24 | MetricReduction, |
25 | 25 | convert_data_type, |
26 | 26 | convert_to_tensor, |
| 27 | + convert_to_numpy, |
27 | 28 | ensure_tuple_rep, |
28 | 29 | look_up_option, |
29 | 30 | optional_import, |
@@ -165,38 +166,41 @@ def get_mask_edges( |
165 | 166 | seg_pred = seg_pred == label_idx |
166 | 167 | if seg_gt.dtype not in (bool, torch.bool): |
167 | 168 | seg_gt = seg_gt == label_idx |
168 | | - |
169 | 169 | if crop: |
170 | | - if not (seg_pred | seg_gt).any(): |
171 | | - pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt) |
| 170 | + or_vol = seg_pred | seg_gt |
| 171 | + if not or_vol.any(): |
| 172 | + pred, gt = np.zeros(seg_pred.shape, dtype=bool), np.zeros(seg_gt.shape, dtype=bool) |
172 | 173 | return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore |
| 174 | + channel_first = [seg_pred[None], seg_gt[None], or_vol[None]] |
| 175 | + if spacing is None: # cpu only erosion |
| 176 | + seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool) |
| 177 | + else: # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported |
| 178 | + seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16) |
173 | 179 | cropper = CropForegroundD( |
174 | 180 | ["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None |
175 | 181 | ) |
176 | | - mask = seg_pred | seg_gt |
177 | | - cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore |
178 | | - seg_pred = cropped["pred"][0] |
179 | | - seg_gt = cropped["gt"][0] |
180 | | - |
181 | | - if spacing is None: |
182 | | - # Do binary erosion and use XOR to get edges |
183 | | - seg_pred = convert_data_type(seg_pred, np.ndarray)[0] |
184 | | - seg_gt = convert_data_type(seg_gt, np.ndarray)[0] |
| 182 | + cropped = cropper({"pred": seg_pred, "gt": seg_gt, "src": or_vol}) # type: ignore |
| 183 | + seg_pred, seg_gt = cropped["pred"][0], cropped["gt"][0] |
| 184 | + |
| 185 | + if spacing is None: # Do binary erosion and use XOR to get edges |
| 186 | + seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool) |
185 | 187 | edges_pred = binary_erosion(seg_pred) ^ seg_pred |
186 | 188 | edges_gt = binary_erosion(seg_gt) ^ seg_gt |
187 | 189 | return edges_pred, edges_gt |
188 | | - code_to_area_table, k = get_code_to_measure_table(spacing) |
| 190 | + code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device) |
189 | 191 | spatial_dims = len(spacing) |
190 | 192 | conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d |
191 | | - code_pred, code_gt = conv(torch.stack([seg_pred[None], seg_gt[None]], dim=0).float(), k.float()) # type: ignore |
| 193 | + vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float() |
| 194 | + code_pred, code_gt = conv(vol, k.to(vol)) # type: ignore |
192 | 195 | # edges |
193 | 196 | all_ones = len(code_to_area_table) - 1 |
194 | 197 | edges_pred = (code_pred != 0) & (code_pred != all_ones) |
195 | 198 | edges_gt = (code_gt != 0) & (code_gt != all_ones) |
196 | 199 | # areas of edges |
197 | 200 | areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape) |
198 | 201 | areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape) |
199 | | - return edges_pred.array[0], edges_gt.array[0], areas_pred.array[0], areas_gt.array[0] # type: ignore |
| 202 | + ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0]) |
| 203 | + return convert_to_numpy(ret, wrap_sequence=False) |
200 | 204 |
|
201 | 205 |
|
202 | 206 | def get_surface_distance( |
|
0 commit comments