Skip to content

Commit c967959

Browse files
MarijnS95claude
andauthored
Bind compute PSO inside ComputeEncoder::dispatch() (#1217)
`ComputeEncoder::dispatch()` now takes a `PipelineState &` as its first argument and binds the PSO before issuing the dispatch, the same shape as `RenderEncoder::drawInstanced()`. Making the PSO a required parameter enforces at the type level that callers cannot forget to bind a pipeline (an invalid dispatch on every backend), and keeps the bind adjacent to the command it applies to instead of relying on separate "last bound" state on the command list. It also hides per-API requirements behind the encoder: on Metal the threadgroup-size from the HLSL `numthreads()` annotation is read from shader reflection at pipeline creation and cached on `MTLPipelineState`, instead of leaking out as a `setThreadGroupSize()` call every caller had to remember. `dispatch()` just consumes the cached size, and we drop the encoder-state helpers `ThreadsPerGroup` / `setThreadGroupSize()` entirely. With the bind moved, the per-backend compute PSO bind in `createComputeCommands()` (DX, VK, MTL) is removed. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2034cde commit c967959

4 files changed

Lines changed: 84 additions & 81 deletions

File tree

include/API/Encoder.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,11 @@ class ComputeEncoder : public CommandEncoder {
7171
using CommandEncoder::CommandEncoder;
7272

7373
/// Dispatch a compute grid. GroupCount specifies how many workgroups to
74-
/// launch in each dimension. The workgroup size is derived from the bound
75-
/// pipeline state (e.g. the shader's numthreads attribute).
76-
virtual llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
77-
uint32_t GroupCountZ) = 0;
74+
/// launch in each dimension. The workgroup size is derived from \p PSO
75+
/// (e.g. the shader's numthreads attribute), which is also bound for the
76+
/// dispatch.
77+
virtual llvm::Error dispatch(const PipelineState &PSO, uint32_t GroupCountX,
78+
uint32_t GroupCountY, uint32_t GroupCountZ) = 0;
7879

7980
/// Copy \p Size bytes from \p Src at \p SrcOffset to \p Dst at
8081
/// \p DstOffset.

lib/API/DX/Device.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -657,12 +657,15 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
657657
void popDebugGroup() override {}
658658
void insertDebugSignpost(llvm::StringRef Label) override {}
659659

660-
llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
660+
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
661+
uint32_t GroupCountX, uint32_t GroupCountY,
661662
uint32_t GroupCountZ) override {
663+
const auto &DXPSO = llvm::cast<DXPipelineState>(PSO);
662664
addUAVBarrier();
663665
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
664666
GroupCountY, GroupCountZ)
665667
.str());
668+
CB.CmdList->SetPipelineState(DXPSO.PSO.Get());
666669
CB.CmdList->Dispatch(GroupCountX, GroupCountY, GroupCountZ);
667670
return llvm::Error::success();
668671
}
@@ -2058,7 +2061,6 @@ class DXDevice : public offloadtest::Device {
20582061
const DXPipelineState &DXPipeline =
20592062
llvm::cast<DXPipelineState>(*IS.Pipeline.get());
20602063
IS.CB->CmdList->SetComputeRootSignature(DXPipeline.RootSig.Get());
2061-
IS.CB->CmdList->SetPipelineState(DXPipeline.PSO.Get());
20622064

20632065
const uint32_t Inc = Device->GetDescriptorHandleIncrementSize(
20642066
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
@@ -2133,10 +2135,10 @@ class DXDevice : public offloadtest::Device {
21332135
if (!EncoderOrErr)
21342136
return EncoderOrErr.takeError();
21352137
auto &Encoder = *EncoderOrErr.get();
2136-
if (auto Err =
2137-
Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
2138-
P.DispatchParameters.DispatchGroupCount[1],
2139-
P.DispatchParameters.DispatchGroupCount[2]))
2138+
if (auto Err = Encoder.dispatch(
2139+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
2140+
P.DispatchParameters.DispatchGroupCount[1],
2141+
P.DispatchParameters.DispatchGroupCount[2]))
21402142
return Err;
21412143
Encoder.endEncoding();
21422144
}

lib/API/MTL/MTLDevice.cpp

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,14 @@ class MTLPipelineState : public offloadtest::PipelineState {
239239
std::string Name;
240240
IRRootSignaturePtr RootSig;
241241
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
242-
IRShaderReflectionPtr Reflection;
243242
MTL::ComputePipelineState *ComputePipeline = nullptr;
244243
MTL::RenderPipelineState *RenderPipeline = nullptr;
245244

245+
// Compute pipeline only state. Threadgroup size comes from numthreads() in
246+
// the HLSL source and is captured from shader reflection at pipeline
247+
// creation, so dispatch() doesn't need to re-query reflection each time.
248+
MTL::Size ThreadsPerGroup = MTL::Size(1, 1, 1);
249+
246250
// Rasterization pipeline only state.
247251
// These are part of the pipeline in DX and VK, but dynamic state in Metal.
248252
// To have a shared API we store these here and set the state when the
@@ -252,11 +256,11 @@ class MTLPipelineState : public offloadtest::PipelineState {
252256

253257
MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
254258
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
255-
IRShaderReflectionPtr Reflection,
256-
MTL::ComputePipelineState *ComputePipeline)
259+
MTL::ComputePipelineState *ComputePipeline,
260+
MTL::Size ThreadsPerGroup)
257261
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
258262
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
259-
Reflection(std::move(Reflection)), ComputePipeline(ComputePipeline) {}
263+
ComputePipeline(ComputePipeline), ThreadsPerGroup(ThreadsPerGroup) {}
260264

261265
MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
262266
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
@@ -422,11 +426,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
422426
MTL::ComputeCommandEncoder *ComputeEnc = nullptr;
423427
MTL::BlitCommandEncoder *BlitEnc = nullptr;
424428

425-
/// Threadgroup size from shader reflection (the numthreads() attribute
426-
/// persisted in the transpiled Metallib). Must be set via
427-
/// setThreadGroupSize() before dispatching.
428-
MTL::Size ThreadsPerGroup = {1, 1, 1};
429-
430429
/// Accumulated barrier scope from commands recorded since the last barrier.
431430
MTL::BarrierScope PendingScope = MTL::BarrierScope(0);
432431

@@ -483,13 +482,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
483482

484483
MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }
485484

486-
/// Set the threadgroup size for subsequent dispatch calls. The values must
487-
/// come from shader reflection (the numthreads() attribute in the HLSL
488-
/// source, persisted in the transpiled Metallib).
489-
void setThreadGroupSize(NS::UInteger X, NS::UInteger Y, NS::UInteger Z) {
490-
ThreadsPerGroup = MTL::Size(X, Y, Z);
491-
}
492-
493485
MTL::CommandEncoder *getActiveEncoder() const {
494486
if (ComputeEnc)
495487
return ComputeEnc;
@@ -513,18 +505,26 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
513505
NS::String::string(Label.data(), NS::UTF8StringEncoding));
514506
}
515507

516-
llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
508+
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
509+
uint32_t GroupCountX, uint32_t GroupCountY,
517510
uint32_t GroupCountZ) override {
511+
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
512+
if (!MTLPSO.ComputePipeline)
513+
return llvm::createStringError(
514+
std::errc::invalid_argument,
515+
"PipelineState bound to dispatch() is not a compute pipeline.");
518516
if (auto Err = ensureComputeEncoder())
519517
return Err;
520518
flushBarrier();
521519
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
522520
GroupCountY, GroupCountZ)
523521
.str());
524-
const MTL::Size GridSize(ThreadsPerGroup.width * GroupCountX,
525-
ThreadsPerGroup.height * GroupCountY,
526-
ThreadsPerGroup.depth * GroupCountZ);
527-
ComputeEnc->dispatchThreads(GridSize, ThreadsPerGroup);
522+
ComputeEnc->setComputePipelineState(MTLPSO.ComputePipeline);
523+
524+
const MTL::Size GridSize(MTLPSO.ThreadsPerGroup.width * GroupCountX,
525+
MTLPSO.ThreadsPerGroup.height * GroupCountY,
526+
MTLPSO.ThreadsPerGroup.depth * GroupCountZ);
527+
ComputeEnc->dispatchThreads(GridSize, MTLPSO.ThreadsPerGroup);
528528
addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
529529
return llvm::Error::success();
530530
}
@@ -1287,7 +1287,6 @@ class MTLDevice : public offloadtest::Device {
12871287
MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative();
12881288

12891289
const auto &PS = llvm::cast<MTLPipelineState>(IS.Pipeline.get());
1290-
NativeEncoder->setComputePipelineState(PS->ComputePipeline);
12911290
MTLGPUDescriptorHandle Handle = {};
12921291
if (IS.DescHeap) {
12931292
IS.DescHeap->bind(NativeEncoder);
@@ -1307,21 +1306,8 @@ class MTLDevice : public offloadtest::Device {
13071306
MTL::ResourceUsageRead |
13081307
MTL::ResourceUsageWrite);
13091308

1310-
NS::UInteger TGS[3] = {PS->ComputePipeline->maxTotalThreadsPerThreadgroup(),
1311-
1, 1};
1312-
if (PS->Reflection) {
1313-
IRVersionedCSInfo Info;
1314-
if (IRShaderReflectionCopyComputeInfo(PS->Reflection.get(),
1315-
IRReflectionVersion_1_0, &Info)) {
1316-
TGS[0] = Info.info_1_0.tg_size[0];
1317-
TGS[1] = Info.info_1_0.tg_size[1];
1318-
TGS[2] = Info.info_1_0.tg_size[2];
1319-
}
1320-
IRShaderReflectionReleaseComputeInfo(&Info);
1321-
}
1322-
Encoder.setThreadGroupSize(TGS[0], TGS[1], TGS[2]);
1323-
1324-
if (auto Err = Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
1309+
if (auto Err = Encoder.dispatch(*IS.Pipeline.get(),
1310+
P.DispatchParameters.DispatchGroupCount[0],
13251311
P.DispatchParameters.DispatchGroupCount[1],
13261312
P.DispatchParameters.DispatchGroupCount[2]))
13271313
return Err;
@@ -1574,9 +1560,20 @@ class MTLDevice : public offloadtest::Device {
15741560
if (Error)
15751561
return toError(Error);
15761562

1563+
IRVersionedCSInfo Info;
1564+
if (!IRShaderReflectionCopyComputeInfo(MetalIR->Reflection.get(),
1565+
IRReflectionVersion_1_0, &Info))
1566+
return llvm::createStringError(
1567+
"Failed to read compute reflection for entry point '%s'; cannot "
1568+
"determine threadgroup size from numthreads().",
1569+
CS.EntryPoint.c_str());
1570+
const MTL::Size ThreadsPerGroup(Info.info_1_0.tg_size[0],
1571+
Info.info_1_0.tg_size[1],
1572+
Info.info_1_0.tg_size[2]);
1573+
IRShaderReflectionReleaseComputeInfo(&Info);
1574+
15771575
return std::make_unique<MTLPipelineState>(
1578-
Name, std::move(RootSig), std::move(ArgBuffer),
1579-
std::move(MetalIR->Reflection), PSO);
1576+
Name, std::move(RootSig), std::move(ArgBuffer), PSO, ThreadsPerGroup);
15801577
}
15811578

15821579
llvm::Expected<std::unique_ptr<PipelineState>>

lib/API/VK/Device.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,32 @@ class VulkanCommandBuffer : public offloadtest::CommandBuffer {
703703
VulkanCommandBuffer() : CommandBuffer(GPUAPI::Vulkan) {}
704704
};
705705

706+
class VulkanPipelineState : public offloadtest::PipelineState {
707+
public:
708+
std::string Name;
709+
VkDevice Dev;
710+
VkPipeline Pipeline;
711+
VkPipelineLayout Layout;
712+
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts;
713+
714+
VulkanPipelineState(llvm::StringRef Name, VkDevice Dev, VkPipeline Pipeline,
715+
VkPipelineLayout Layout,
716+
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts)
717+
: offloadtest::PipelineState(GPUAPI::Vulkan), Name(Name.str()), Dev(Dev),
718+
Pipeline(Pipeline), Layout(Layout), SetLayouts(std::move(SetLayouts)) {}
719+
720+
~VulkanPipelineState() override {
721+
vkDestroyPipeline(Dev, Pipeline, nullptr);
722+
vkDestroyPipelineLayout(Dev, Layout, nullptr);
723+
for (VkDescriptorSetLayout L : SetLayouts)
724+
vkDestroyDescriptorSetLayout(Dev, L, nullptr);
725+
}
726+
727+
static bool classof(const offloadtest::PipelineState *B) {
728+
return B->getAPI() == GPUAPI::Vulkan;
729+
}
730+
};
731+
706732
class VKComputeEncoder : public offloadtest::ComputeEncoder {
707733
VulkanCommandBuffer &CB;
708734

@@ -729,13 +755,17 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
729755
CB.insertDebugSignpost(Label);
730756
}
731757

732-
llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
758+
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
759+
uint32_t GroupCountX, uint32_t GroupCountY,
733760
uint32_t GroupCountZ) override {
761+
const auto &VKPSO = llvm::cast<VulkanPipelineState>(PSO);
734762
addDstBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
735763
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT);
736764
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
737765
GroupCountY, GroupCountZ)
738766
.str());
767+
vkCmdBindPipeline(CB.CmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
768+
VKPSO.Pipeline);
739769
vkCmdDispatch(CB.CmdBuffer, GroupCountX, GroupCountY, GroupCountZ);
740770
return llvm::Error::success();
741771
}
@@ -765,32 +795,6 @@ VulkanCommandBuffer::createComputeEncoder() {
765795
return Enc;
766796
}
767797

768-
class VulkanPipelineState : public offloadtest::PipelineState {
769-
public:
770-
std::string Name;
771-
VkDevice Dev;
772-
VkPipeline Pipeline;
773-
VkPipelineLayout Layout;
774-
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts;
775-
776-
VulkanPipelineState(llvm::StringRef Name, VkDevice Dev, VkPipeline Pipeline,
777-
VkPipelineLayout Layout,
778-
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts)
779-
: offloadtest::PipelineState(GPUAPI::Vulkan), Name(Name.str()), Dev(Dev),
780-
Pipeline(Pipeline), Layout(Layout), SetLayouts(std::move(SetLayouts)) {}
781-
782-
~VulkanPipelineState() override {
783-
vkDestroyPipeline(Dev, Pipeline, nullptr);
784-
vkDestroyPipelineLayout(Dev, Layout, nullptr);
785-
for (VkDescriptorSetLayout L : SetLayouts)
786-
vkDestroyDescriptorSetLayout(Dev, L, nullptr);
787-
}
788-
789-
static bool classof(const offloadtest::PipelineState *B) {
790-
return B->getAPI() == GPUAPI::Vulkan;
791-
}
792-
};
793-
794798
static VkAttachmentLoadOp getVkLoadOp(offloadtest::LoadAction Action) {
795799
switch (Action) {
796800
case offloadtest::LoadAction::Load:
@@ -2966,7 +2970,6 @@ class VulkanDevice : public offloadtest::Device {
29662970
: VK_PIPELINE_BIND_POINT_COMPUTE;
29672971
const VulkanPipelineState &VulkanPipeline =
29682972
llvm::cast<VulkanPipelineState>(*IS.Pipeline.get());
2969-
vkCmdBindPipeline(IS.CB->CmdBuffer, BindPoint, VulkanPipeline.Pipeline);
29702973
if (IS.DescriptorSets.size() > 0)
29712974
vkCmdBindDescriptorSets(
29722975
IS.CB->CmdBuffer, BindPoint, VulkanPipeline.Layout, 0,
@@ -2985,10 +2988,10 @@ class VulkanDevice : public offloadtest::Device {
29852988
if (!EncoderOrErr)
29862989
return EncoderOrErr.takeError();
29872990
auto &Encoder = *EncoderOrErr.get();
2988-
if (auto Err =
2989-
Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
2990-
P.DispatchParameters.DispatchGroupCount[1],
2991-
P.DispatchParameters.DispatchGroupCount[2]))
2991+
if (auto Err = Encoder.dispatch(
2992+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
2993+
P.DispatchParameters.DispatchGroupCount[1],
2994+
P.DispatchParameters.DispatchGroupCount[2]))
29922995
return Err;
29932996
Encoder.endEncoding();
29942997
llvm::outs() << "Dispatched compute shader: { "

0 commit comments

Comments
 (0)