@@ -2,7 +2,9 @@ use std::ffi::CString;
22
33use llvm:: Linkage :: * ;
44use rustc_abi:: Align ;
5+ use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
56use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
7+ use rustc_middle:: bug;
68use rustc_middle:: ty:: offload_meta:: OffloadMetadata ;
79
810use crate :: builder:: Builder ;
@@ -69,6 +71,57 @@ impl<'ll> OffloadGlobals<'ll> {
6971 }
7072}
7173
74+ pub ( crate ) struct OffloadKernelDims < ' ll > {
75+ num_workgroups : & ' ll Value ,
76+ threads_per_block : & ' ll Value ,
77+ workgroup_dims : & ' ll Value ,
78+ thread_dims : & ' ll Value ,
79+ }
80+
81+ impl < ' ll > OffloadKernelDims < ' ll > {
82+ pub ( crate ) fn from_operands < ' tcx > (
83+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
84+ workgroup_op : & OperandRef < ' tcx , & ' ll llvm:: Value > ,
85+ thread_op : & OperandRef < ' tcx , & ' ll llvm:: Value > ,
86+ ) -> Self {
87+ let cx = builder. cx ;
88+ let arr_ty = cx. type_array ( cx. type_i32 ( ) , 3 ) ;
89+ let four = Align :: from_bytes ( 4 ) . unwrap ( ) ;
90+
91+ let OperandValue :: Ref ( place) = workgroup_op. val else {
92+ bug ! ( "expected array operand by reference" ) ;
93+ } ;
94+ let workgroup_val = builder. load ( arr_ty, place. llval , four) ;
95+
96+ let OperandValue :: Ref ( place) = thread_op. val else {
97+ bug ! ( "expected array operand by reference" ) ;
98+ } ;
99+ let thread_val = builder. load ( arr_ty, place. llval , four) ;
100+
101+ fn mul_dim3 < ' ll , ' tcx > (
102+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
103+ arr : & ' ll Value ,
104+ ) -> & ' ll Value {
105+ let x = builder. extract_value ( arr, 0 ) ;
106+ let y = builder. extract_value ( arr, 1 ) ;
107+ let z = builder. extract_value ( arr, 2 ) ;
108+
109+ let xy = builder. mul ( x, y) ;
110+ builder. mul ( xy, z)
111+ }
112+
113+ let num_workgroups = mul_dim3 ( builder, workgroup_val) ;
114+ let threads_per_block = mul_dim3 ( builder, thread_val) ;
115+
116+ OffloadKernelDims {
117+ workgroup_dims : workgroup_val,
118+ thread_dims : thread_val,
119+ num_workgroups,
120+ threads_per_block,
121+ }
122+ }
123+ }
124+
72125// ; Function Attrs: nounwind
73126// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
74127fn generate_launcher < ' ll > ( cx : & CodegenCx < ' ll , ' _ > ) -> ( & ' ll llvm:: Value , & ' ll llvm:: Type ) {
@@ -204,12 +257,12 @@ impl KernelArgsTy {
204257 num_args : u64 ,
205258 memtransfer_types : & ' ll Value ,
206259 geps : [ & ' ll Value ; 3 ] ,
260+ workgroup_dims : & ' ll Value ,
261+ thread_dims : & ' ll Value ,
207262 ) -> [ ( Align , & ' ll Value ) ; 13 ] {
208263 let four = Align :: from_bytes ( 4 ) . expect ( "4 Byte alignment should work" ) ;
209264 let eight = Align :: EIGHT ;
210265
211- let ti32 = cx. type_i32 ( ) ;
212- let ci32_0 = cx. get_const_i32 ( 0 ) ;
213266 [
214267 ( four, cx. get_const_i32 ( KernelArgsTy :: OFFLOAD_VERSION ) ) ,
215268 ( four, cx. get_const_i32 ( num_args) ) ,
@@ -222,8 +275,8 @@ impl KernelArgsTy {
222275 ( eight, cx. const_null ( cx. type_ptr ( ) ) ) , // dbg
223276 ( eight, cx. get_const_i64 ( KernelArgsTy :: TRIPCOUNT ) ) ,
224277 ( eight, cx. get_const_i64 ( KernelArgsTy :: FLAGS ) ) ,
225- ( four, cx . const_array ( ti32 , & [ cx . get_const_i32 ( 2097152 ) , ci32_0 , ci32_0 ] ) ) ,
226- ( four, cx . const_array ( ti32 , & [ cx . get_const_i32 ( 256 ) , ci32_0 , ci32_0 ] ) ) ,
278+ ( four, workgroup_dims ) ,
279+ ( four, thread_dims ) ,
227280 ( four, cx. get_const_i32 ( 0 ) ) ,
228281 ]
229282 }
@@ -413,10 +466,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
413466 types : & [ & Type ] ,
414467 metadata : & [ OffloadMetadata ] ,
415468 offload_globals : & OffloadGlobals < ' ll > ,
469+ offload_dims : & OffloadKernelDims < ' ll > ,
416470) {
417471 let cx = builder. cx ;
418472 let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
419473 offload_data;
474+ let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
475+ offload_dims;
420476
421477 let tgt_decl = offload_globals. launcher_fn ;
422478 let tgt_target_kernel_ty = offload_globals. launcher_ty ;
@@ -554,7 +610,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
554610 num_args,
555611 s_ident_t,
556612 ) ;
557- let values = KernelArgsTy :: new ( & cx, num_args, memtransfer_types, geps) ;
613+ let values =
614+ KernelArgsTy :: new ( & cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims) ;
558615
559616 // Step 3)
560617 // Here we fill the KernelArgsTy, see the documentation above
@@ -567,9 +624,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
567624 s_ident_t,
568625 // FIXME(offload) give users a way to select which GPU to use.
569626 cx. get_const_i64( u64 :: MAX ) , // MAX == -1.
570- // FIXME(offload): Don't hardcode the numbers of threads in the future.
571- cx. get_const_i32( 2097152 ) ,
572- cx. get_const_i32( 256 ) ,
627+ num_workgroups,
628+ threads_per_block,
573629 region_id,
574630 a5,
575631 ] ;
0 commit comments