Skip to content

Commit 619bb1d

Browse files
[FRONTEND][TFLITE] Add support for TFLite's regular NMS operator (#15117)
This PR adds support for TFLite's regular NMS operator.

Open questions:

1. How to properly test the added functionality? Other NMS implementations, e.g., fast NMS, use a TF frozen graph from the official TF website, convert the model to TFLite, and keep only the NMS operations. In order to create a similar test, we need to find an archive on the official TF website that contains a frozen graph of a model compiled with the --use-regular-nms=True flag. We haven't found one yet, so any help is appreciated.

2. Regular NMS requires two sort operations:

   - Sorting the scores after selecting scores above nms_score_threshold. This PR implements this with a simple bubble sort in order to prove the algorithm's semantics. We tried to replace it with tvm.contrib.sort.argsort. It works well when testing the regular NMS with run_tvm_graph, as the fast NMS test does, or when building and running the regular NMS with the llvm target. However, it fails to build (the error is provided below) when target=ethos-u,cmsis-nn,c. It seems that the __tvm_module_ctx variable is only initialized when the cpp runtime is chosen. The error:

     error: '__tvm_module_ctx' undeclared (first use in this function)
     203 | if (TVMBackendGetFuncFromEnv(__tvm_module_ctx, "tvm.contrib.sort.argsort", &tvm_contrib_sort_argsort_packed) != 0) {

   - Sorting the scores of the previous and current NMS steps. There are two alternatives here:
     - implement some sorting algorithm as part of the hybrid script (to replace the current bubble sort);
     - save the result of each NMS step and use argsort after the hybrid script part. This approach has a drawback, as it requires a significant amount of memory to store the results of each NMS step.
1 parent 0556653 commit 619bb1d

14 files changed

Lines changed: 695 additions & 49 deletions

File tree

include/tvm/relay/attrs/vision.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,17 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAtt
6161
bool clip;
6262
double threshold;
6363
Array<IndexExpr> variances;
64+
bool keep_background;
6465

6566
TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") {
6667
TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes.");
6768
TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction.");
6869
TVM_ATTR_FIELD(variances)
6970
.set_default(Array<IndexExpr>({0.1f, 0.1f, 0.2f, 0.2f}))
7071
.describe("Variances to be decoded from box regression output.");
72+
TVM_ATTR_FIELD(keep_background)
73+
.set_default(false)
74+
.describe("Whether to keep boxes detected as background or not");
7175
}
7276
};
7377

@@ -129,6 +133,27 @@ struct AllClassNonMaximumSuppressionAttrs
129133
}
130134
};
131135

136+
/*! \brief Attributes used in regular_non_maximum_suppression operator */
137+
struct RegularNonMaximumSuppressionAttrs
138+
: public tvm::AttrsNode<RegularNonMaximumSuppressionAttrs> {
139+
int32_t max_detections_per_class;
140+
int32_t max_detections;
141+
int32_t num_classes;
142+
double iou_threshold;
143+
double score_threshold;
144+
145+
TVM_DECLARE_ATTRS(RegularNonMaximumSuppressionAttrs,
146+
"relay.attrs.RegularNonMaximumSuppressionAttrs") {
147+
TVM_ATTR_FIELD(max_detections_per_class)
148+
.describe("The maxinum number of output selected boxes per class.");
149+
TVM_ATTR_FIELD(max_detections).describe("The maxinum number of output selected boxes.");
150+
TVM_ATTR_FIELD(num_classes).describe("The number of classes without background.");
151+
TVM_ATTR_FIELD(iou_threshold).describe("The IoU threshold for box the overlap test.");
152+
TVM_ATTR_FIELD(score_threshold)
153+
.describe("Score threshold to filter out low score boxes early.");
154+
}
155+
};
156+
132157
/*! \brief Attributes used in roi_align operators */
133158
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
134159
Array<IndexExpr> pooled_size;

python/tvm/relay/frontend/tflite.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3443,12 +3443,7 @@ def convert_detection_postprocess(self, op):
34433443
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
34443444
custom_options = FlexBufferDecoder(flexbuffer).decode()
34453445

3446-
if "use_regular_nms" in custom_options:
3447-
if custom_options["use_regular_nms"]:
3448-
raise tvm.error.OpAttributeUnImplemented(
3449-
"use_regular_nms=True is not yet supported for operator "
3450-
"TFLite_Detection_PostProcess."
3451-
)
3446+
use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"]
34523447

34533448
inputs = self.get_input_tensors(op)
34543449
assert len(inputs) == 3, "inputs length should be 3"
@@ -3481,15 +3476,14 @@ def convert_detection_postprocess(self, op):
34813476
input_zero_point=inputs[2].qnn_params["zero_point"],
34823477
)
34833478

3484-
# reshape the cls_pred and loc_prob tensors so
3485-
# they can be consumed by multibox_transform_loc
3486-
cls_pred = _op.transpose(cls_pred, [0, 2, 1])
34873479
# loc_prob coords are in yxhw format
34883480
# need to convert to xywh
34893481
loc_coords = _op.split(loc_prob, 4, axis=2)
34903482
loc_prob = _op.concatenate(
34913483
[loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
34923484
)
3485+
# reshape loc_prob tensor so it can be consumed by
3486+
# multibox_transform_loc
34933487
loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes * 4])
34943488

34953489
# anchor coords are in yxhw format
@@ -3511,13 +3505,41 @@ def convert_detection_postprocess(self, op):
35113505
# attributes for multibox_transform_loc
35123506
multibox_transform_loc_attrs = {}
35133507
multibox_transform_loc_attrs["clip"] = False
3514-
multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"]
3508+
multibox_transform_loc_attrs["threshold"] = (
3509+
0.0 if use_regular_nms else custom_options["nms_score_threshold"]
3510+
)
35153511
multibox_transform_loc_attrs["variances"] = (
35163512
1 / custom_options["x_scale"],
35173513
1 / custom_options["y_scale"],
35183514
1 / custom_options["w_scale"],
35193515
1 / custom_options["h_scale"],
35203516
)
3517+
multibox_transform_loc_attrs["keep_background"] = use_regular_nms
3518+
3519+
ret = _op.vision.multibox_transform_loc(
3520+
# reshape cls_pred so it can be consumed by
3521+
# multibox_transform_loc
3522+
_op.transpose(cls_pred, [0, 2, 1]),
3523+
loc_prob,
3524+
anchor_expr,
3525+
**multibox_transform_loc_attrs,
3526+
)
3527+
3528+
if use_regular_nms:
3529+
# box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax)
3530+
_, transformed_boxes = _op.split(ret[0], (2,), axis=2)
3531+
box_l, box_t, box_r, box_b = _op.split(transformed_boxes, 4, axis=2)
3532+
transformed_boxes = _op.concatenate([box_t, box_l, box_b, box_r], axis=2)
3533+
3534+
return _op.vision.regular_non_max_suppression(
3535+
boxes=transformed_boxes,
3536+
scores=cls_pred,
3537+
max_detections_per_class=custom_options["detections_per_class"],
3538+
max_detections=custom_options["max_detections"],
3539+
num_classes=custom_options["num_classes"],
3540+
iou_threshold=custom_options["nms_iou_threshold"],
3541+
score_threshold=custom_options["nms_score_threshold"],
3542+
)
35213543

35223544
# attributes for non_max_suppression
35233545
non_max_suppression_attrs = {}
@@ -3528,9 +3550,6 @@ def convert_detection_postprocess(self, op):
35283550
non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
35293551
non_max_suppression_attrs["invalid_to_bottom"] = False
35303552

3531-
ret = _op.vision.multibox_transform_loc(
3532-
cls_pred, loc_prob, anchor_expr, **multibox_transform_loc_attrs
3533-
)
35343553
ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs)
35353554
ret = _op.vision.get_valid_counts(ret, 0)
35363555
valid_count = ret[0]

python/tvm/relay/op/strategy/generic.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,10 @@ def _compute_multibox_transform_loc(attrs, inputs, _):
11841184
clip = bool(get_const_int(attrs.clip))
11851185
threshold = get_const_float(attrs.threshold)
11861186
variances = get_float_tuple(attrs.variances)
1187-
return topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances)
1187+
keep_background = bool(get_const_int(attrs.keep_background))
1188+
return topi_compute(
1189+
inputs[0], inputs[1], inputs[2], clip, threshold, variances, keep_background
1190+
)
11881191

11891192
return _compute_multibox_transform_loc
11901193

@@ -1316,6 +1319,35 @@ def all_class_nms_strategy(attrs, inputs, out_type, target):
13161319
return strategy
13171320

13181321

1322+
def wrap_compute_regular_nms(topi_compute):
1323+
"""wrap regular nms topi compute"""
1324+
1325+
def _compute_nms(attrs, inputs, out_type):
1326+
return topi_compute(
1327+
inputs[0],
1328+
inputs[1],
1329+
attrs.max_detections_per_class,
1330+
attrs.max_detections,
1331+
attrs.num_classes,
1332+
attrs.iou_threshold,
1333+
attrs.score_threshold,
1334+
)
1335+
1336+
return _compute_nms
1337+
1338+
1339+
@override_native_generic_func("regular_non_max_suppression_strategy")
1340+
def regular_nms_strategy(attrs, inputs, out_type, target):
1341+
"""regular nms generic strategy"""
1342+
strategy = _op.OpStrategy()
1343+
strategy.add_implementation(
1344+
wrap_compute_regular_nms(topi.vision.regular_non_max_suppression),
1345+
wrap_topi_schedule(topi.generic.schedule_nms),
1346+
name="regular_nms.generic",
1347+
)
1348+
return strategy
1349+
1350+
13191351
# roi_align
13201352
def wrap_compute_roi_align(topi_compute):
13211353
"""wrap roi_align topi compute"""

python/tvm/relay/op/vision/_vision.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy)
4949
reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE)
5050

51+
reg.register_strategy("vision.regular_non_max_suppression", strategy.regular_nms_strategy)
52+
reg.register_pattern("vision.regular_non_max_suppression", OpPattern.OPAQUE)
53+
5154

5255
@script
5356
def _get_valid_counts_shape_func(data_shape):
@@ -122,6 +125,33 @@ def all_class_nms_shape_func(attrs, inputs, _):
122125
return _all_class_nms_shape_func_tf(inputs[0], inputs[1])
123126

124127

128+
@script
129+
def _regular_nms_shape_func(boxes_shape, scores_shape, attrs):
130+
out_boxes_shape = output_tensor((3,), "int64")
131+
out_classes_shape = output_tensor((2,), "int64")
132+
out_scores_shape = output_tensor((2,), "int64")
133+
out_num_detections_shape = output_tensor((1,), "int64")
134+
135+
out_boxes_shape[0] = boxes_shape[0]
136+
out_boxes_shape[1] = int64(attrs.max_detections)
137+
out_boxes_shape[2] = int64(4)
138+
139+
out_classes_shape[0] = boxes_shape[0]
140+
out_classes_shape[1] = int64(attrs.max_detections)
141+
142+
out_scores_shape[0] = boxes_shape[0]
143+
out_scores_shape[1] = int64(attrs.max_detections)
144+
145+
out_num_detections_shape[0] = boxes_shape[0]
146+
147+
return out_boxes_shape, out_classes_shape, out_scores_shape, out_num_detections_shape
148+
149+
150+
@reg.register_shape_func("vision.regular_non_max_suppression", False)
151+
def regular_nms_shape_func(attrs, inputs, _):
152+
return _regular_nms_shape_func(inputs[0], inputs[1], attrs)
153+
154+
125155
@script
126156
def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size):
127157
out = output_tensor((4,), "int64")

python/tvm/relay/op/vision/multibox.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ def multibox_prior(
5353

5454

5555
def multibox_transform_loc(
56-
cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)
56+
cls_prob,
57+
loc_pred,
58+
anchor,
59+
clip=True,
60+
threshold=0.01,
61+
variances=(0.1, 0.1, 0.2, 0.2),
62+
keep_background=False,
5763
):
5864
"""Location transformation for multibox detection
5965
@@ -77,10 +83,22 @@ def multibox_transform_loc(
7783
variances : Tuple of float, optional
7884
variances to be decoded from box regression output.
7985
86+
keep_background : boolean, optional
87+
Whether to keep boxes detected as background or not.
88+
8089
Returns
8190
-------
8291
ret : tuple of tvm.relay.Expr
8392
"""
8493
return expr.TupleWrapper(
85-
_make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances), 2
94+
_make.multibox_transform_loc(
95+
cls_prob,
96+
loc_pred,
97+
anchor,
98+
clip,
99+
threshold,
100+
variances,
101+
keep_background,
102+
),
103+
2,
86104
)

python/tvm/relay/op/vision/nms.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,62 @@ def all_class_non_max_suppression(
226226
return expr.TupleWrapper(out, 2)
227227

228228
return expr.TupleWrapper(out, 3)
229+
230+
231+
def regular_non_max_suppression(
232+
boxes,
233+
scores,
234+
max_detections_per_class,
235+
max_detections,
236+
num_classes,
237+
iou_threshold,
238+
score_threshold,
239+
):
240+
"""Regular non-maximum suppression operator for object detection, corresponding to TFLite's
241+
regular NMS. NMS is performed for each class separately.
242+
243+
Parameters
244+
----------
245+
boxes : relay.Expr
246+
3-D tensor with shape (batch_size, num_boxes, 4). The four values in boxes
247+
encode (ymin, xmin, ymax, xmax) coordinates of a box
248+
249+
scores: relay.Expr
250+
3-D tensor with shape (batch_size, num_boxes, num_classes_with_background)
251+
252+
max_detections_per_class : int
253+
The maximum number of output selected boxes per class
254+
255+
max_detections : int
256+
The maximum number of output selected boxes
257+
258+
num_classes : int
259+
The number of classes without background
260+
261+
iou_threshold : float
262+
IoU test threshold
263+
264+
score_threshold : float
265+
Score threshold to filter out low score boxes early
266+
267+
Returns
268+
-------
269+
out : relay.Tuple
270+
The output is a relay.Tuple of four tensors. The first is `detection_boxes` of size
271+
`(batch_size, max_detections , 4)`, the second is `detection_classes` of size
272+
`(batch_size, max_detections)`, the third is `detection_scores` of size
273+
`(batch_size, max_detections)`, and the fourth is `num_detections` of size `(batch_size,)`
274+
representing the total number of selected boxes per batch.
275+
"""
276+
return expr.TupleWrapper(
277+
_make.regular_non_max_suppression(
278+
boxes,
279+
scores,
280+
max_detections_per_class,
281+
max_detections,
282+
num_classes,
283+
iou_threshold,
284+
score_threshold,
285+
),
286+
4,
287+
)

0 commit comments

Comments
 (0)