Skip to content

Commit 7d37fff

Browse files
committed
Update sort and nms ir
1 parent 8dd4bfa commit 7d37fff

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

topi/python/topi/cuda/rcnn/proposal.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def argsort_ir(data_buf, out_index_buf):
138138
p_data = ib.buffer_ptr(data_buf)
139139
index_out = ib.buffer_ptr(out_index_buf)
140140
nthread_tx = max_threads
141-
nthread_bx = num_bbox // max_threads + 1
141+
nthread_bx = (num_bbox + 1) // 2 // max_threads + 1
142142
tx = tvm.thread_axis("threadIdx.x")
143143
bx = tvm.thread_axis("vthread")
144144
ib.scope_attr(tx, "thread_extent", nthread_tx)
@@ -149,9 +149,10 @@ def argsort_ir(data_buf, out_index_buf):
149149

150150
with ib.for_range(0, batch, for_type="unroll") as b:
151151
start = b * num_bbox
152-
with ib.if_scope(tid < num_bbox):
153-
index_out[start + tid] = tid
154-
152+
for i in range(2):
153+
bbox_id = tid * 2 + i
154+
with ib.if_scope(bbox_id < num_bbox):
155+
index_out[start + bbox_id] = bbox_id
155156
with ib.for_range(0, num_bbox) as k:
156157
offset = start + 2 * tid + (k % 2)
157158
with ib.if_scope(
@@ -213,17 +214,16 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
213214
nthread_bx = num_bbox // max_threads + 1
214215
ib.scope_attr(tx, "thread_extent", nthread_tx)
215216
ib.scope_attr(bx, "thread_extent", nthread_bx)
216-
j = bx * max_threads + tx
217+
i = bx * max_threads + tx
217218
with ib.for_range(0, batch, for_type="unroll", name="n") as b:
218-
start = b * num_bbox
219-
with ib.if_scope(j < num_bbox):
220-
p_out[start + j] = False
221-
222-
with ib.for_range(0, num_bbox - 1) as i:
223-
with ib.if_scope(tvm.all(j < num_bbox, j > i, p_out[start + i] == False)):
224-
iou = calculate_overlap(p_data, (start + i) * 5, (start + j) * 5)
219+
base_idx = b * num_bbox
220+
with ib.if_scope(i < num_bbox):
221+
p_out[base_idx + i] = False
222+
with ib.for_range(0, num_bbox - 1) as l:
223+
with ib.if_scope(tvm.all(i < num_bbox, i > l, p_out[base_idx + l] == False)):
224+
iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
225225
with ib.if_scope(iou > nms_threshold):
226-
p_out[start + j] = True
226+
p_out[base_idx + i] = True
227227
return ib.get()
228228

229229

0 commit comments

Comments
 (0)