99
1010import numba_dpex as dpex
1111import numba_dpex .experimental as dpex_exp
12- from numba_dpex .kernel_api import (
13- LocalAccessor ,
14- MemoryScope ,
15- NdItem ,
16- group_barrier ,
17- )
12+ from numba_dpex .kernel_api import LocalAccessor , NdItem
13+ from numba_dpex .kernel_api import call_kernel as kapi_call_kernel
1814from numba_dpex .tests ._helper import get_all_dtypes
1915
2016list_of_supported_dtypes = get_all_dtypes (
2117 no_bool = True , no_float16 = True , no_none = True , no_complex = True
2218)
2319
2420
def _kernel1(nd_item: NdItem, a, slm):
    """Accumulate ``sum(i * m for m in range(100))`` through ``slm`` into ``a``.

    ``i`` is the work-item's global linear id. The running total is kept in
    the ``slm`` entry addressed by the work-item's dimension-0 local id and
    is copied to ``a[i]`` once the loop finishes.
    """
    global_idx = nd_item.get_global_linear_id()

    # TODO: overload nd_item.get_local_id()
    # slm is indexed with a tuple, so the single local id is wrapped in one.
    local_idx = (nd_item.get_local_id(0),)

    # Start from zero, then add one term per iteration directly in slm so
    # that every step reads and writes the slm argument.
    slm[local_idx] = 0
    for step in range(100):
        slm[local_idx] += global_idx * step

    a[global_idx] = slm[local_idx]
4033
4134
def _kernel2(nd_item: NdItem, a, slm):
    """Two-dimensional variant: index ``slm`` by local ids 0 and 1.

    For the work-item with global linear id ``i``, the ``slm`` entry at
    ``(local_id(0), local_id(1))`` is zeroed, incremented by ``i * m`` for
    every ``m`` in ``range(100)``, and finally stored into ``a[i]``.
    """
    global_idx = nd_item.get_global_linear_id()

    # TODO: overload nd_item.get_local_id()
    local_0 = nd_item.get_local_id(0)
    local_1 = nd_item.get_local_id(1)
    local_idx = (local_0, local_1)

    # Accumulate in slm itself (not in a scalar) so each iteration performs
    # a read-modify-write on the slm argument.
    slm[local_idx] = 0
    for step in range(100):
        slm[local_idx] += global_idx * step

    a[global_idx] = slm[local_idx]
5747
5848
59- @dpex_exp .kernel
6049def _kernel3 (nd_item : NdItem , a , slm ):
6150 i = nd_item .get_global_linear_id ()
6251
@@ -68,15 +57,23 @@ def _kernel3(nd_item: NdItem, a, slm):
6857 )
6958
7059 slm [j ] = 0
71- group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
7260
7361 for m in range (100 ):
7462 slm [j ] += i * m
75- group_barrier (nd_item .get_group (), MemoryScope .WORK_GROUP )
7663
7764 a [i ] = slm [j ]
7865
7966
def device_func_kernel(func):
    """Return a kernel that runs ``func`` indirectly, as a device function.

    ``func`` is first passed through ``dpex_exp.device_func``. The returned
    object is a ``dpex_exp.kernel`` whose body only forwards its
    ``(item, a, slm)`` arguments to that device function.
    """
    device_fn = dpex_exp.device_func(func)

    def _forwarding_kernel(item, a, slm):
        # Hand all three arguments to the device function unchanged.
        device_fn(item, a, slm)

    return dpex_exp.kernel(_forwarding_kernel)
75+
76+
8077@pytest .mark .parametrize ("supported_dtype" , list_of_supported_dtypes )
8178@pytest .mark .parametrize (
8279 "nd_range, _kernel" ,
@@ -86,7 +83,17 @@ def _kernel3(nd_item: NdItem, a, slm):
8683 (dpex .NdRange ((1 , 32 , 1 ), (1 , 32 , 1 )), _kernel3 ),
8784 ],
8885)
89- def test_local_accessor (supported_dtype , nd_range : dpex .NdRange , _kernel ):
86+ @pytest .mark .parametrize (
87+ "call_kernel, kernel" ,
88+ [
89+ (dpex_exp .call_kernel , dpex_exp .kernel ),
90+ (dpex_exp .call_kernel , device_func_kernel ),
91+ (kapi_call_kernel , lambda f : f ),
92+ ],
93+ )
94+ def test_local_accessor (
95+ supported_dtype , nd_range : dpex .NdRange , _kernel , call_kernel , kernel
96+ ):
9097 """A test for passing a LocalAccessor object as a kernel argument."""
9198
9299 N = 32
@@ -98,7 +105,7 @@ def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
98105 # `4950 * get_global_linear_id` and stores it into the work groups local
99106 # memory. The local memory is of size 32*64 elements of the requested dtype.
100107 # The result is then stored into `a` in global memory
101- dpex_exp . call_kernel (_kernel , nd_range , a , slm )
108+ call_kernel (kernel ( _kernel ) , nd_range , a , slm )
102109
103110 for idx in range (N ):
104111 assert a [idx ] == 4950 * idx
0 commit comments