@@ -9,7 +9,7 @@ use crate::builder_spirv::{SpirvFunctionCursor, SpirvValue, SpirvValueExt};
99use crate :: spirv_type:: SpirvType ;
1010use rspirv:: dr:: Operand ;
1111use rspirv:: spirv:: {
12- Capability , Decoration , Dim , ExecutionModel , FunctionControl , StorageClass , Word ,
12+ BuiltIn , Capability , Decoration , Dim , ExecutionModel , FunctionControl , StorageClass , Word ,
1313} ;
1414use rustc_abi:: FieldsShape ;
1515use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods , MiscCodegenMethods as _} ;
@@ -722,8 +722,9 @@ impl<'tcx> CodegenCx<'tcx> {
722722 Ok ( StorageClass :: Input | StorageClass :: Output | StorageClass :: UniformConstant )
723723 ) ;
724724 let mut assign_location = |var_id : Result < Word , & str > , explicit : Option < u32 > | {
725+ let storage_class = storage_class. unwrap ( ) ;
725726 let location = decoration_locations
726- . entry ( storage_class. unwrap ( ) )
727+ . entry ( storage_class)
727728 . or_insert_with ( || 0 ) ;
728729 if let Some ( explicit) = explicit {
729730 * location = explicit;
@@ -733,7 +734,46 @@ impl<'tcx> CodegenCx<'tcx> {
733734 Decoration :: Location ,
734735 std:: iter:: once ( Operand :: LiteralBit32 ( * location) ) ,
735736 ) ;
736- let spirv_type = self . lookup_type ( value_spirv_type) ;
737+ let mut spirv_type = self . lookup_type ( value_spirv_type) ;
738+
739+ // These shader types and storage classes skip the outer array or pointer of the declaration when computing
740+ // the location layout, see bug at https://github.com/Rust-GPU/rust-gpu/issues/500.
741+ //
742+ // The match statment follows the rules at:
743+ // https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#interfaces-iointerfaces-matching
744+ #[ allow( clippy:: match_same_arms) ]
745+ let can_skip_outer_array =
746+ match ( execution_model, storage_class, attrs. per_primitive_ext ) {
747+ // > if the input is declared in a tessellation control or geometry shader...
748+ (
749+ ExecutionModel :: TessellationControl | ExecutionModel :: Geometry ,
750+ StorageClass :: Input ,
751+ _,
752+ ) => true ,
753+ // > if the maintenance4 feature is enabled, they are declared as OpTypeVector variables, and the
754+ // > output has a Component Count value higher than that of the input but the same Component Type
755+ // Irrelevant: This allows a vertex shader to output a Vec4 and a fragment shader to accept a vector
756+ // type with fewer components, like Vec3, Vec2 (or f32?). Which has no influence on locations.
757+ // > if the output is declared in a mesh shader...
758+ ( ExecutionModel :: MeshEXT | ExecutionModel :: MeshNV , StorageClass :: Output , _) => {
759+ true
760+ }
761+ // > if the input is decorated with PerVertexKHR, and is declared in a fragment shader...
762+ ( ExecutionModel :: Fragment , StorageClass :: Input , Some ( _) ) => true ,
763+ // > if in any other case...
764+ ( _, _, _) => false ,
765+ } ;
766+ if can_skip_outer_array {
767+ spirv_type = match spirv_type {
768+ SpirvType :: Array { element, .. }
769+ | SpirvType :: RuntimeArray { element, .. }
770+ | SpirvType :: Pointer {
771+ pointee : element, ..
772+ } => self . lookup_type ( element) ,
773+ e => e,
774+ } ;
775+ }
776+
737777 if let Some ( location_size) = spirv_type. location_size ( self ) {
738778 * location += location_size;
739779 } else {
@@ -916,6 +956,11 @@ impl<'tcx> CodegenCx<'tcx> {
916956 ) ;
917957 }
918958
959+ // Check builtin-specific type requirements.
960+ if let Some ( builtin) = attrs. builtin {
961+ self . check_builtin_type ( hir_param. ty_span , value_layout. ty , builtin) ;
962+ }
963+
919964 if let Ok ( storage_class) = storage_class {
920965 self . check_for_bad_types (
921966 execution_model,
@@ -1083,4 +1128,15 @@ impl<'tcx> CodegenCx<'tcx> {
10831128 }
10841129 }
10851130 }
1131+
1132+ /// Check that builtin variables have the correct type.
1133+ fn check_builtin_type ( & self , span : Span , rust_ty : Ty < ' tcx > , builtin : Spanned < BuiltIn > ) {
1134+ // LocalInvocationIndex must be a u32.
1135+ if builtin. value == BuiltIn :: LocalInvocationIndex && rust_ty != self . tcx . types . u32 {
1136+ self . tcx . dcx ( ) . span_err (
1137+ span,
1138+ format ! ( "`#[spirv(local_invocation_index)]` must be a `u32`, not `{rust_ty}`" ) ,
1139+ ) ;
1140+ }
1141+ }
10861142}
0 commit comments