@@ -138,7 +138,7 @@ def argsort_ir(data_buf, out_index_buf):
138138 p_data = ib .buffer_ptr (data_buf )
139139 index_out = ib .buffer_ptr (out_index_buf )
140140 nthread_tx = max_threads
141- nthread_bx = num_bbox // max_threads + 1
141+ nthread_bx = ( num_bbox + 1 ) // 2 // max_threads + 1
142142 tx = tvm .thread_axis ("threadIdx.x" )
143143 bx = tvm .thread_axis ("vthread" )
144144 ib .scope_attr (tx , "thread_extent" , nthread_tx )
@@ -149,9 +149,10 @@ def argsort_ir(data_buf, out_index_buf):
149149
150150 with ib .for_range (0 , batch , for_type = "unroll" ) as b :
151151 start = b * num_bbox
152- with ib .if_scope (tid < num_bbox ):
153- index_out [start + tid ] = tid
154-
152+ for i in range (2 ):
153+ bbox_id = tid * 2 + i
154+ with ib .if_scope (bbox_id < num_bbox ):
155+ index_out [start + bbox_id ] = bbox_id
155156 with ib .for_range (0 , num_bbox ) as k :
156157 offset = start + 2 * tid + (k % 2 )
157158 with ib .if_scope (
@@ -213,17 +214,16 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
213214 nthread_bx = num_bbox // max_threads + 1
214215 ib .scope_attr (tx , "thread_extent" , nthread_tx )
215216 ib .scope_attr (bx , "thread_extent" , nthread_bx )
216- j = bx * max_threads + tx
217+ i = bx * max_threads + tx
217218 with ib .for_range (0 , batch , for_type = "unroll" , name = "n" ) as b :
218- start = b * num_bbox
219- with ib .if_scope (j < num_bbox ):
220- p_out [start + j ] = False
221-
222- with ib .for_range (0 , num_bbox - 1 ) as i :
223- with ib .if_scope (tvm .all (j < num_bbox , j > i , p_out [start + i ] == False )):
224- iou = calculate_overlap (p_data , (start + i ) * 5 , (start + j ) * 5 )
219+ base_idx = b * num_bbox
220+ with ib .if_scope (i < num_bbox ):
221+ p_out [base_idx + i ] = False
222+ with ib .for_range (0 , num_bbox - 1 ) as l :
223+ with ib .if_scope (tvm .all (i < num_bbox , i > l , p_out [base_idx + l ] == False )):
224+ iou = calculate_overlap (p_data , (base_idx + l ) * 5 , (base_idx + i ) * 5 )
225225 with ib .if_scope (iou > nms_threshold ):
226- p_out [start + j ] = True
226+ p_out [base_idx + i ] = True
227227 return ib .get ()
228228
229229
0 commit comments