Skip to content

Commit 0d526e5

Browse files
author
Ziqiang XU
committed
Complete pytorch grid_sample
PyTorch's grid_sample() supports various interpolation options: (1) data dimension: 2D / 3D; (2) interpolation method: nearest / bilinear / bicubic; (3) padding_mode: zeros / border / reflection; (4) align_corners: True / False. However, TVM only supports a subset of the above options: (1) data dimension: 2D; (2) interpolation method: bilinear; (3) padding_mode: zeros / border; (4) align_corners: True. This commit adds the options not yet supported by TVM and keeps the existing grid_sample behavior of the ONNX/PyTorch frontends unaffected. Co-authored-by: shukun.net
1 parent ae285c6 commit 0d526e5

File tree

11 files changed

+1086
-179
lines changed

11 files changed

+1086
-179
lines changed

include/tvm/relay/attrs/image.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -276,23 +276,44 @@ struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
276276
String method;
277277
String layout;
278278
String padding_mode;
279+
bool align_corners;
279280

280281
TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
281282
TVM_ATTR_FIELD(method)
282283
.set_default("bilinear")
283284
.describe(
284285
"Specify the mode to use for scaling."
285-
"bilinear - Bilinear Interpolation");
286+
"nearest - 2D or 3D Nearest Interpolation."
287+
"bilinear - '2D Bilinear' or '3D Trilinear' Interpolation."
288+
"bicubic - 2D Bicubic Interpolation.");
286289
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
287-
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
288-
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
289-
"dimensions respectively. Resize is applied on the 'H' and"
290-
"'W' dimensions.");
290+
"Dimension ordering of input data. Can be 'NCHW', 'NCDHW', etc."
291+
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
292+
"dimensions respectively."
293+
"2D Resize is applied on the 'H' and 'W' dimensions."
294+
"3D Resize is applied on the 'D' and 'H' and 'W' dimensions.");
291295
TVM_ATTR_FIELD(padding_mode)
292296
.set_default("zeros")
293297
.describe(
294-
"Specify the padding mode to use."
295-
"zeros, border etc.");
298+
"If :attr:'grid' has values outside the range of '[-1, 1]', the corresponding"
299+
"outputs are handled as defined by padding_mode. Options are"
300+
"padding_mode='zeros': use '0' for out-of-bound grid locations,"
301+
"padding_mode='border': use border values for out-of-bound grid locations"
302+
"padding_mode='reflection': use values at locations reflected by"
303+
"the border for out-of-bound grid locations. For location far away"
304+
"from the border, it will keep being reflected until becoming in bound,"
305+
"e.g., (normalized) pixel location 'x = -3.5' reflects by border '-1'"
306+
"and becomes 'x' = 1.5, then reflects by border '1' and becomes"
307+
"'x' = -0.5");
308+
TVM_ATTR_FIELD(align_corners)
309+
.set_default(true)
310+
.describe(
311+
"Geometrically, we consider the pixels of the"
312+
"input as squares rather than points."
313+
"If set to True, the extrema (-1 and 1) are considered as referring"
314+
"to the center points of the input's corner pixels. If set to False, they"
315+
"are instead considered as referring to the corner points of the input's corner"
316+
"pixels, making the sampling more resolution agnostic.");
296317
}
297318
};
298319

python/tvm/relay/frontend/pytorch.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2931,23 +2931,46 @@ def mv(self, inputs, _):
29312931
return _op.transform.squeeze(dense_result)
29322932

29332933
def grid_sampler(self, inputs, input_types):
2934-
if inputs[2] == 0:
2935-
mode = "bilinear"
2934+
interpolate_mode = inputs[2]
2935+
padding_mode = inputs[3]
2936+
align_corners = inputs[4]
2937+
data_shape = self.infer_shape_with_prelude(inputs[0])
2938+
2939+
if len(data_shape) == 4:
2940+
layout = "NCHW"
2941+
axes = [0, 3, 1, 2]
2942+
grid = _op.transform.transpose(inputs[1], axes)
2943+
elif len(data_shape) == 5:
2944+
layout = "NCDHW"
2945+
axes = [0, 4, 1, 2, 3]
2946+
grid = _op.transform.transpose(inputs[1], axes)
29362947
else:
2937-
msg = "Only bilinear mode is supported in grid_sampler"
2938-
raise NotImplementedError(msg)
2939-
2940-
if inputs[3] == 0:
2941-
padding_mode = "zeros"
2942-
elif inputs[3] == 1:
2943-
padding_mode = "border"
2948+
msg = f"only 4D and 5D are supported."
2949+
raise ValueError(msg)
2950+
2951+
if interpolate_mode == 0:
2952+
interpolate_str = "bilinear"
2953+
elif interpolate_mode == 1:
2954+
interpolate_str = "nearest"
2955+
elif interpolate_mode == 2:
2956+
interpolate_str = "bicubic"
29442957
else:
2945-
msg = "Only zeros and border padding mode are supported in grid_sampler"
2946-
raise NotImplementedError(msg)
2958+
msg = f"interpolation method {interpolate_mode} is not supported"
2959+
raise ValueError(msg)
2960+
2961+
if padding_mode == 0:
2962+
padding_mode_str = "zeros"
2963+
elif padding_mode == 1:
2964+
padding_mode_str = "border"
2965+
elif padding_mode == 2:
2966+
padding_mode_str = "reflection"
2967+
else:
2968+
msg = f"padding_mode {padding_mode} is not supported"
2969+
raise ValueError(msg)
29472970

2948-
axes = [0, 3, 1, 2]
2949-
grid = _op.transform.transpose(inputs[1], axes)
2950-
return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode)
2971+
return _op.image.grid_sample(
2972+
inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
2973+
)
29512974

29522975
# Operator mappings
29532976
def create_convert_map(self):

python/tvm/relay/op/image/_image.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,17 @@ def compute_grid_sample(attrs, inputs, out_dtype):
366366
method = attrs.method
367367
layout = attrs.layout
368368
padding_mode = attrs.padding_mode
369-
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)]
369+
align_corners = attrs.align_corners
370+
return [
371+
topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode, align_corners)
372+
]
370373

371374

372375
reg.register_injective_schedule("image.grid_sample")
373376

374377

375378
@script
376-
def _grid_sample_func(data, grid):
379+
def _grid_sample_func_nchw(data, grid):
377380
out = output_tensor((4,), "int64")
378381
out[0] = int64(data[0])
379382
out[1] = int64(data[1])
@@ -382,9 +385,27 @@ def _grid_sample_func(data, grid):
382385
return out
383386

384387

388+
@script
389+
def _grid_sample_func_ncdhw(data, grid):
390+
out = output_tensor((5,), "int64")
391+
out[0] = int64(data[0])
392+
out[1] = int64(data[1])
393+
out[2] = int64(grid[2])
394+
out[3] = int64(grid[3])
395+
out[4] = int64(grid[4])
396+
return out
397+
398+
385399
@reg.register_shape_func("image.grid_sample", False)
386400
def grid_sample_func(attrs, inputs, _):
387401
"""
388402
Shape function for grid_sample op.
389403
"""
390-
return [_grid_sample_func(inputs[0], inputs[1])]
404+
if attrs.layout == "NCHW":
405+
script_func = _grid_sample_func_nchw
406+
elif attrs.layout == "NCDHW":
407+
script_func = _grid_sample_func_ncdhw
408+
else:
409+
msg = f"layout {attrs.layout} is not supported"
410+
raise ValueError(msg)
411+
return [script_func(inputs[0], inputs[1])]

python/tvm/relay/op/image/image.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -455,22 +455,33 @@ def affine_grid(data, target_shape=None):
455455
return _make.affine_grid(data, target_shape)
456456

457457

458-
def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
459-
"""Applies bilinear sampling to input feature map.
458+
def grid_sample(
459+
data, grid, method="bilinear", layout="NCHW", padding_mode="zeros", align_corners=True
460+
):
461+
"""Applies grid sampling to input feature map.
460462
461-
Given :math:`data` and :math:`grid`, then the output is computed by
463+
Given :math:`data` and :math:`grid`, then for 4-D the output is computed by
462464
463465
.. math::
464466
465467
x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
466468
y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
467-
output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src})
469+
output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}])
468470
469471
:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
470472
:math:`G()` denotes the interpolation function.
471-
The out-boundary points will be padded with zeros if padding_mode is "zeros".
473+
474+
The out-boundary points will be padded with zeros if padding_mode is `zeros`, or
475+
border pixel value if padding_mode is `border`, or
476+
inner pixel value if padding_mode is `reflection`.
477+
478+
The top-left corner (-1, -1) and bottom-right corner (1, 1) in grid will be mapped to
479+
(0, 0) and (h - 1, w - 1) of data if align_corners is `True`, or
480+
(-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is `False`.
481+
472482
The shape of the output will be
473-
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
483+
4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or
484+
5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]).
474485
475486
The operator assumes that :math:`grid` has been normalized to [-1, 1].
476487
@@ -479,23 +490,34 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zero
479490
Parameters
480491
----------
481492
data : tvm.Tensor
482-
4-D with shape [batch, in_channel, in_height, in_width]
493+
4-D with shape [batch, in_channel, in_height, in_width], or
494+
5-D with shape [batch, in_channel, in_depth, in_height, in_width]
483495
484496
grid : tvm.Tensor
485-
4-D with shape [batch, 2, out_height, out_width]
497+
4-D with shape [batch, 2, out_height, out_width], or
498+
5-D with shape [batch, 3, out_depth, out_height, out_width]
486499
487500
method : str
488-
The interpolation method. Only 'bilinear' is supported.
501+
The interpolation method. For 4-D data, `nearest`, `bilinear` and `bicubic` are
502+
supported; for 5-D data, `nearest` and `bilinear` (trilinear) are supported.
489503
490504
layout : str
491505
The layout of input data and the output.
492506
493507
padding_mode : str
494-
The padding mode for outside grid values.
508+
The padding mode for outside grid values. `zeros`, `border` and `reflection` are supported.
509+
510+
align_corners: bool
511+
Geometrically, we consider the pixels of the input as squares rather than points.
512+
If set to `True`, the extrema (`-1` and `1`) are considered as referring
513+
to the center points of the input's corner pixels. If set to `False`, they
514+
are instead considered as referring to the corner points of the input's corner
515+
pixels, making the sampling more resolution agnostic.
495516
496517
Returns
497518
-------
498519
Output : tvm.Tensor
499-
4-D with shape [batch, 2, out_height, out_width]
520+
4-D with shape [batch, in_channel, out_height, out_width], or
521+
5-D with shape [batch, in_channel, out_depth, out_height, out_width]
500522
"""
501-
return _make.grid_sample(data, grid, method, layout, padding_mode)
523+
return _make.grid_sample(data, grid, method, layout, padding_mode, align_corners)

0 commit comments

Comments
 (0)