@@ -239,10 +239,14 @@ class MTLPipelineState : public offloadtest::PipelineState {
239239 std::string Name;
240240 IRRootSignaturePtr RootSig;
241241 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
242- IRShaderReflectionPtr Reflection;
243242 MTL::ComputePipelineState *ComputePipeline = nullptr ;
244243 MTL::RenderPipelineState *RenderPipeline = nullptr ;
245244
245+ // Compute pipeline only state. Threadgroup size comes from numthreads() in
246+ // the HLSL source and is captured from shader reflection at pipeline
247+ // creation, so dispatch() doesn't need to re-query reflection each time.
248+ MTL::Size ThreadsPerGroup = MTL::Size(1 , 1 , 1 );
249+
246250 // Rasterization pipeline only state.
247251 // These are part of the pipeline in DX and VK, but dynamic state in Metal.
248252 // To have a shared API we store these here and set the state when the
@@ -252,11 +256,11 @@ class MTLPipelineState : public offloadtest::PipelineState {
252256
253257 MTLPipelineState (llvm::StringRef Name, IRRootSignaturePtr RootSig,
254258 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
255- IRShaderReflectionPtr Reflection ,
256- MTL::ComputePipelineState *ComputePipeline )
259+ MTL::ComputePipelineState *ComputePipeline ,
260+ MTL::Size ThreadsPerGroup )
257261 : offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
258262 RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
259- Reflection(std::move(Reflection)), ComputePipeline(ComputePipeline ) {}
263+ ComputePipeline(ComputePipeline), ThreadsPerGroup(ThreadsPerGroup ) {}
260264
261265 MTLPipelineState (llvm::StringRef Name, IRRootSignaturePtr RootSig,
262266 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
@@ -422,11 +426,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
422426 MTL::ComputeCommandEncoder *ComputeEnc = nullptr ;
423427 MTL::BlitCommandEncoder *BlitEnc = nullptr ;
424428
425- // / Threadgroup size from shader reflection (the numthreads() attribute
426- // / persisted in the transpiled Metallib). Must be set via
427- // / setThreadGroupSize() before dispatching.
428- MTL::Size ThreadsPerGroup = {1 , 1 , 1 };
429-
430429 // / Accumulated barrier scope from commands recorded since the last barrier.
431430 MTL::BarrierScope PendingScope = MTL::BarrierScope(0 );
432431
@@ -483,13 +482,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
483482
484483 MTL::ComputeCommandEncoder *getNative () const { return ComputeEnc; }
485484
486- // / Set the threadgroup size for subsequent dispatch calls. The values must
487- // / come from shader reflection (the numthreads() attribute in the HLSL
488- // / source, persisted in the transpiled Metallib).
489- void setThreadGroupSize (NS::UInteger X, NS::UInteger Y, NS::UInteger Z) {
490- ThreadsPerGroup = MTL::Size (X, Y, Z);
491- }
492-
493485 MTL::CommandEncoder *getActiveEncoder () const {
494486 if (ComputeEnc)
495487 return ComputeEnc;
@@ -513,18 +505,26 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
513505 NS::String::string (Label.data(), NS::UTF8StringEncoding));
514506 }
515507
516- llvm::Error dispatch (uint32_t GroupCountX, uint32_t GroupCountY,
508+ llvm::Error dispatch (const offloadtest::PipelineState &PSO,
509+ uint32_t GroupCountX, uint32_t GroupCountY,
517510 uint32_t GroupCountZ) override {
511+ const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
512+ if (!MTLPSO.ComputePipeline )
513+ return llvm::createStringError (
514+ std::errc::invalid_argument,
515+ " PipelineState bound to dispatch() is not a compute pipeline." );
518516 if (auto Err = ensureComputeEncoder ())
519517 return Err;
520518 flushBarrier ();
521519 insertDebugSignpost (llvm::formatv (" Dispatch [{0},{1},{2}]" , GroupCountX,
522520 GroupCountY, GroupCountZ)
523521 .str ());
524- const MTL::Size GridSize (ThreadsPerGroup.width * GroupCountX,
525- ThreadsPerGroup.height * GroupCountY,
526- ThreadsPerGroup.depth * GroupCountZ);
527- ComputeEnc->dispatchThreads (GridSize, ThreadsPerGroup);
522+ ComputeEnc->setComputePipelineState (MTLPSO.ComputePipeline );
523+
524+ const MTL::Size GridSize (MTLPSO.ThreadsPerGroup .width * GroupCountX,
525+ MTLPSO.ThreadsPerGroup .height * GroupCountY,
526+ MTLPSO.ThreadsPerGroup .depth * GroupCountZ);
527+ ComputeEnc->dispatchThreads (GridSize, MTLPSO.ThreadsPerGroup );
528528 addBarrierScope (MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
529529 return llvm::Error::success ();
530530 }
@@ -1287,7 +1287,6 @@ class MTLDevice : public offloadtest::Device {
12871287 MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative ();
12881288
12891289 const auto &PS = llvm::cast<MTLPipelineState>(IS.Pipeline .get ());
1290- NativeEncoder->setComputePipelineState (PS->ComputePipeline );
12911290 MTLGPUDescriptorHandle Handle = {};
12921291 if (IS.DescHeap ) {
12931292 IS.DescHeap ->bind (NativeEncoder);
@@ -1307,21 +1306,8 @@ class MTLDevice : public offloadtest::Device {
13071306 MTL::ResourceUsageRead |
13081307 MTL::ResourceUsageWrite);
13091308
1310- NS::UInteger TGS[3 ] = {PS->ComputePipeline ->maxTotalThreadsPerThreadgroup (),
1311- 1 , 1 };
1312- if (PS->Reflection ) {
1313- IRVersionedCSInfo Info;
1314- if (IRShaderReflectionCopyComputeInfo (PS->Reflection .get (),
1315- IRReflectionVersion_1_0, &Info)) {
1316- TGS[0 ] = Info.info_1_0 .tg_size [0 ];
1317- TGS[1 ] = Info.info_1_0 .tg_size [1 ];
1318- TGS[2 ] = Info.info_1_0 .tg_size [2 ];
1319- }
1320- IRShaderReflectionReleaseComputeInfo (&Info);
1321- }
1322- Encoder.setThreadGroupSize (TGS[0 ], TGS[1 ], TGS[2 ]);
1323-
1324- if (auto Err = Encoder.dispatch (P.DispatchParameters .DispatchGroupCount [0 ],
1309+ if (auto Err = Encoder.dispatch (*IS.Pipeline .get (),
1310+ P.DispatchParameters .DispatchGroupCount [0 ],
13251311 P.DispatchParameters .DispatchGroupCount [1 ],
13261312 P.DispatchParameters .DispatchGroupCount [2 ]))
13271313 return Err;
@@ -1574,9 +1560,20 @@ class MTLDevice : public offloadtest::Device {
15741560 if (Error)
15751561 return toError (Error);
15761562
1563+ IRVersionedCSInfo Info;
1564+ if (!IRShaderReflectionCopyComputeInfo (MetalIR->Reflection .get (),
1565+ IRReflectionVersion_1_0, &Info))
1566+ return llvm::createStringError (
1567+ " Failed to read compute reflection for entry point '%s'; cannot "
1568+ " determine threadgroup size from numthreads()." ,
1569+ CS.EntryPoint .c_str ());
1570+ const MTL::Size ThreadsPerGroup (Info.info_1_0 .tg_size [0 ],
1571+ Info.info_1_0 .tg_size [1 ],
1572+ Info.info_1_0 .tg_size [2 ]);
1573+ IRShaderReflectionReleaseComputeInfo (&Info);
1574+
15771575 return std::make_unique<MTLPipelineState>(
1578- Name, std::move (RootSig), std::move (ArgBuffer),
1579- std::move (MetalIR->Reflection ), PSO);
1576+ Name, std::move (RootSig), std::move (ArgBuffer), PSO, ThreadsPerGroup);
15801577 }
15811578
15821579 llvm::Expected<std::unique_ptr<PipelineState>>
0 commit comments