Skip to content

Commit 87bb8b1

Browse files
authored
[TIR] Introduce Pass InjectPTXLDG32 (#13973)
This PR introduces a new pass, InjectPTXLDG32, which rewrites an `if_then_else` call node into a `ptx_ldg32` call node. When the store buffer is local and the loaded value comes from global memory, the pass replaces the `if_then_else` pattern with a predicated PTX load. Test the pass with:
```python
with tvm.transform.PassContext(config={"tir.ptx_ldg32": True}):
    mod = tvm.build(f, target="cuda")
```
1 parent 054c11e commit 87bb8b1

7 files changed

Lines changed: 260 additions & 0 deletions

File tree

include/tvm/tir/builtin.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,18 @@ TVM_DLL const Op& tvm_store_matrix_sync();
610610
*/
611611
TVM_DLL const Op& ptx_mma();
612612

613+
/*!
614+
* \brief tvm intrinsic for ptx predicate load with 32-bit data type.
615+
*
616+
*/
617+
TVM_DLL const Op& ptx_ldg32();
618+
619+
/*!
620+
* \brief tvm intrinsic for ptx predicate load with 32-bit data type.
621+
*
622+
*/
623+
TVM_DLL const Op& ptx_ldg32();
624+
613625
/*!
614626
* \brief tvm intrinsic for sparse tensor core ptx instructions.
615627
*

include/tvm/tir/transform.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,12 @@ TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
677677
*/
678678
TVM_DLL Pass InjectPTXAsyncCopy();
679679

680+
/*!
681+
* \brief Pass to rewrite global to local memory copy on CUDA with ldg32 instruction.
* \param enable_ptx_ldg32 Whether to apply the rewrite. When false, functions are returned unchanged.
682+
* \return The pass.
683+
*/
684+
TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true);
685+
680686
/*!
681687
* \brief Remove the weight layout rewrite block
682688
* \param skip_ndarray_rewrite If True, exact rewrite of NDArray, according to the given index map,

src/driver/driver_api.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
5555
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
5656
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
5757
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
58+
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
5859

5960
// WARNING: May cause coherency issues resulting data miscompares
6061
// Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When
@@ -159,6 +160,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
159160
bool enable_equiv_terms_in_cse_tir =
160161
pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();
161162

163+
bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();
164+
162165
// Get any user-added passes
163166
Array<Array<ObjectRef>> add_lower_pass =
164167
pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
@@ -257,6 +260,10 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
257260
pass_list.push_back(tir::transform::InstrumentBoundCheckers());
258261
}
259262

263+
if (ptx_ldg32) {
264+
pass_list.push_back(tir::transform::InjectPTXLDG32(true));
265+
}
266+
260267
pass_list.push_back(
261268
tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));
262269

@@ -584,6 +591,11 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
584591
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
585592
}
586593

594+
bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();
595+
if (ptx_ldg32) {
596+
mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
597+
}
598+
587599
bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
588600
.value_or(relay::Executor::Create("graph", {}))
589601
->GetAttr<Bool>("unpacked-api")

src/target/source/codegen_cuda.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,37 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
926926
} else if (op->op.same_as(builtin::ptx_wait_group())) {
927927
std::string N = this->PrintExpr(op->args[0]);
928928
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
929+
} else if (op->op.same_as(builtin::ptx_ldg32())) {
  /*
   * Lower ptx_ldg32(local_data, guard, global_load, local_index) to a
   * predicated 32-bit global load. The generated CUDA has the form
   *
   *   asm volatile (
   *     "{.reg .pred p;\n"
   *     " setp.ne.b32 p, %2, 0;\n"
   *     " @!p mov.b32 %0, 0;\n"
   *     " @p ld.global.nc.f32 %0, [%1];}\n"
   *     : "=f"(reg[local_addr])
   *     : "l"((void*)(global_buffer + global_addr)), "r"((int)guard)
   *   );
   *
   * i.e. the destination element receives the loaded value when the guard is
   * non-zero and zero otherwise.
   */
  ICHECK_EQ(op->args.size(), 4U) << "ptx_ldg32 expects 4 arguments";

  // Destination buffer variable.
  std::string reg = this->PrintExpr(op->args[0]);
  // Guard (predicate) of the load.
  std::string guard = this->PrintExpr(op->args[1]);
  // Source element; it must be a BufferLoad so that the base pointer and the
  // element offset can be printed separately.
  const BufferLoadNode* addr_buffer = op->args[2].as<BufferLoadNode>();
  ICHECK(addr_buffer != nullptr) << "ptx_ldg32 expects a BufferLoad as its third argument";
  ICHECK_EQ(addr_buffer->indices.size(), 1U)
      << "ptx_ldg32 expects the source buffer to be accessed with a single index";
  std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
  std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
  // Element index into the destination buffer.
  std::string local_addr = this->PrintExpr(op->args[3]);

  this->stream << "asm volatile (\n";
  this->stream << "\"{.reg .pred p;\\n\"\n";
  this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
  this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
  this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
  // Alternative load with an L2 prefetch hint, currently not emitted:
  // this->stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n";
  this->stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
               << ")\n";
  this->stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
               << ")), \"r\"((int)" << guard << ")\n";
  this->stream << ");\n";
929960
} else {
930961
CodeGenC::VisitExpr_(op, os);
931962
}

src/tir/op/builtin.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
251251

252252
TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
253253
Integer(CallEffectKind::kOpaque));
254+
// ptx_ldg32 writes the loaded value into its destination buffer argument, so it has a
// side effect and must not be registered as a pure expression (a pure call wrapped in an
// Evaluate may be dropped or deduplicated by later passes).
TIR_DEFINE_BUILTIN_FUNC(ptx_ldg32).set_num_inputs(4).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
254256

255257
TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
256258
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <tvm/arith/analyzer.h>
21+
#include <tvm/arith/iter_affine_map.h>
22+
#include <tvm/runtime/registry.h>
23+
#include <tvm/tir/analysis.h>
24+
#include <tvm/tir/op.h>
25+
#include <tvm/tir/stmt.h>
26+
#include <tvm/tir/stmt_functor.h>
27+
#include <tvm/tir/transform.h>
28+
29+
#include "../../arith/const_fold.h"
30+
#include "../../arith/pattern_match.h"
31+
32+
namespace tvm {
33+
namespace tir {
34+
35+
class PTXRewriter : public StmtMutator {
36+
public:
37+
Stmt VisitStmt_(const AllocateNode* allocate) final {
38+
if (!has_buffer_1) {
39+
has_buffer_1 = true;
40+
// addr[0] -> global_addr / addr[1] -> local_addr
41+
addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
42+
predicate_buffer =
43+
decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(1), "predicate", "local");
44+
}
45+
Stmt result = StmtMutator::VisitStmt_(allocate);
46+
if (!has_buffer_2) {
47+
has_buffer_2 = true;
48+
result =
49+
Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), result);
50+
result = Allocate(predicate_buffer->data, predicate_buffer->dtype, predicate_buffer->shape,
51+
Bool(true), result);
52+
}
53+
return result;
54+
}
55+
56+
Stmt VisitStmt_(const BufferStoreNode* store) final {
57+
Stmt result = StmtMutator::VisitStmt_(store);
58+
Buffer load_buffer = store->buffer;
59+
PrimExpr load_value = store->value;
60+
// const BufferLoadNode* gload = load_value.as<BufferLoadNode>(); // take
61+
// the place of instance of
62+
const CallNode* call = load_value.as<CallNode>();
63+
if (call != nullptr) {
64+
const OpNode* op = call->op.as<OpNode>();
65+
if (op != nullptr && op->name == "tir.if_then_else") {
66+
const PrimExpr& predicate = call->args[0];
67+
const PrimExpr& lhs = call->args[1];
68+
const PrimExpr& rhs = call->args[2];
69+
PrimExpr global_addr, local_addr;
70+
const BufferLoadNode* load = lhs.as<BufferLoadNode>();
71+
PrimExpr imm_value = rhs;
72+
if (load == nullptr) {
73+
load = rhs.as<BufferLoadNode>();
74+
imm_value = lhs;
75+
if (load == nullptr) {
76+
return result;
77+
}
78+
}
79+
global_addr = load->indices[0];
80+
const RampNode* ramp = global_addr.as<RampNode>();
81+
if (ramp != nullptr) {
82+
return result;
83+
}
84+
local_addr = store->indices[0];
85+
BufferStore addr_store(addr_buffer, global_addr, {IntImm(DataType::Int(32), 0)});
86+
BufferStore local_addr_store(addr_buffer, local_addr, {IntImm(DataType::Int(32), 1)});
87+
BufferStore predicate_store(predicate_buffer, predicate, {IntImm(DataType::Int(32), 0)});
88+
PrimExpr new_lhs, new_rhs, new_predicate, new_indice;
89+
new_lhs =
90+
BufferLoad(load->buffer, {BufferLoad(addr_buffer, {IntImm(DataType::Int(32), 0)})});
91+
new_rhs = IntImm(DataType::Int(32), 0);
92+
new_predicate = BufferLoad(predicate_buffer, {IntImm(DataType::Int(32), 0)});
93+
new_indice = BufferLoad(addr_buffer, {IntImm(DataType::Int(32), 1)});
94+
BufferStore value_store(store->buffer, imm_value, {new_indice});
95+
Evaluate ptx_load(Call(store->buffer->dtype, tvm::tir::builtin::ptx_ldg32(),
96+
{store->buffer->data, new_predicate, new_lhs, new_indice}));
97+
Array<Stmt> tmp_seq = {addr_store, local_addr_store, predicate_store, value_store,
98+
ptx_load};
99+
SeqStmt seq_stmt = SeqStmt(tmp_seq);
100+
return seq_stmt;
101+
}
102+
}
103+
return result;
104+
}
105+
106+
bool has_buffer_1 = false, has_buffer_2 = false;
107+
Buffer addr_buffer, predicate_buffer;
108+
};
109+
110+
namespace transform {
111+
112+
/*!
 * \brief Create the function-level pass that applies PTXRewriter to a PrimFunc body.
 * \param enable_inject_ptx_intrin When false the pass returns every function unchanged.
 * \return The pass, registered under the name "tir.InjectPTXLDG32" at opt level 0.
 */
Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) {
  auto rewrite_func = [enable_inject_ptx_intrin](PrimFunc func, IRModule mod,
                                                 PassContext pass_ctx) {
    if (!enable_inject_ptx_intrin) {
      return func;
    }
    // Obtain a uniquely owned node before replacing the body.
    auto* func_node = func.CopyOnWrite();
    func_node->body = PTXRewriter()(func_node->body);
    return func;
  };
  return CreatePrimFuncPass(rewrite_func, 0, "tir.InjectPTXLDG32", {});
}
123+
124+
// Expose the pass to the global function registry as "tir.transform.InjectPTXLDG32". The pass
125+
// can already be invoked via the pass infrastructure; this adds the Python binding for it.
126+
TVM_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32);
127+
128+
} // namespace transform
129+
} // namespace tir
130+
} // namespace tvm
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
from tvm.script import tir as T
19+
import numpy as np
20+
import tvm.testing
21+
22+
23+
@T.prim_func
def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> None:
    # Kernel under test: with 32 threads in one block, every even thread tx reads
    # A[tx // 2] into its local slot while odd threads get 0; B[tx] is that value plus one.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    A_local = T.Buffer((32), "float32", scope="local")

    with T.block():
        T.reads(A[0:16])
        T.writes(A_local[0:32])
        # Use floor division for the integer index; it matches the `i // 2` used by the
        # NumPy reference in the test below.
        A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0), dtype="float32")
        B[tx] = A_local[tx] + 1.0
37+
38+
39+
@tvm.testing.requires_cuda
def test_inject_ptx_intrin():
    """Build `vector_add` with "tir.ptx_ldg32" enabled and compare the result against NumPy."""
    compute_version = tvm.contrib.nvcc.get_target_compute_version()
    major, _ = tvm.contrib.nvcc.parse_compute_version(compute_version)
    if major < 8:
        # Require at least SM80
        return

    with tvm.transform.PassContext(config={"tir.ptx_ldg32": True}):
        mod = tvm.build(vector_add, target="cuda")

    A_np = np.random.rand(16).astype("float32")
    B_np = np.zeros((32)).astype("float32")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)

    # Reference result: even slots hold A[i // 2], odd slots hold 0, then every slot is
    # incremented by one.
    C_np = np.zeros((32)).astype("float32")
    C_np[0::2] = A_np
    C_np += 1.0

    tvm.testing.assert_allclose(B_nd.numpy(), C_np)
64+
65+
66+
if __name__ == "__main__":
67+
test_inject_ptx_intrin()

0 commit comments

Comments
 (0)