[TIR] Shuffle in PointerValueTypeRewrite for scalar reads #15517
tqchen merged 5 commits into apache:main
Conversation
Added an option `rewrite_scalar_read_to_vector_shuffle` in `PointerValueTypeRewrite` (currently only enabled for Vulkan). When enabled and a buffer has both scalar and vector reads, the buffer pointer is rewritten to a vector type if possible, and the scalar reads are expressed via `T.Shuffle`. Closes apache#15463.

cc @tqchen @Lunderberg @sunggg
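For illustration, here is a minimal sketch of the rewrite (hypothetical buffer names; written in the style of the printed TVMScript below, so it may not round-trip through the parser verbatim):

```python
from tvm.script import tir as T

# Before the rewrite: A has both a 4-lane vector read and a scalar read,
# so its pointer would otherwise have to stay scalar ("float16").
@T.prim_func
def before(a: T.handle("float16", "global"), b: T.handle("float16", "global")):
    A = T.decl_buffer((8,), "float16", data=a)
    B = T.decl_buffer((8,), "float16", data=b)
    B[0:4] = A[0:4]  # vector read
    B[4] = A[5]      # scalar read

# After the rewrite (sketch): the pointer becomes "float16x4" and the scalar
# read A[5] is recovered as lane 1 of vector element A4[1] via T.Shuffle.
@T.prim_func
def after(a: T.handle("float16x4", "global"), b: T.handle("float16", "global")):
    A4 = T.decl_buffer((2,), "float16x4", data=a)
    B = T.decl_buffer((8,), "float16", data=b)
    B[0:4] = A4[0]
    B[4] = T.Shuffle([A4[1]], [1])
```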
@Lunderberg do you mind helping review this?
I can, though I probably won't have time available to do so until later next week.
ah ok, I can help take a look then
```cpp
if (me->coeff > 0) {
  // When coeff == 0, the index is constant and doesn't need to be recorded since it can
  // always be rewritten to shuffle.
  var_info.access_dtype.insert(access_dtype.with_lanes(me->coeff));
```
i think we need a data structure that captures two things:

```cpp
struct VarReadInfo {
  DataType access_dtype;
  // maintained as the GCD of all coefficients of the access indices
  int64_t simd_coeff;
};
```

This is mainly to ensure that we don't end up using a very large vector here.
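To make the suggestion concrete, a hedged sketch in plain Python (mirroring the proposed struct, not actual TVM code) of how the GCD bookkeeping bounds the vector width:

```python
import math

# Hypothetical per-variable bookkeeping mirroring the proposed VarReadInfo:
# simd_coeff is folded as the GCD of every access-index coefficient, so the
# final vector width divides every observed access stride.
simd_coeff = 0  # gcd(0, x) == x, so 0 serves as the identity element
for coeff in (8, 4, 6):  # hypothetical coefficients gathered from accesses
    simd_coeff = math.gcd(simd_coeff, coeff)
print(simd_coeff)  # 2: rewrite the pointer to at most a 2-lane vector type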
access_dtype also contains the dtypes of writes, so the resulting dtype will eventually be bounded by the vectorization lanes of the writes.
the main corner case would be when we have only reads and no writes.
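A hedged sketch of that corner case (hypothetical kernel, not from this PR): every access to `A` is a read, so only the read coefficients are available to pick the pointer type.

```python
from tvm.script import tir as T

# Read-only case: A is read with 4 lanes and with stride-2 scalars; with no
# writes to bound the dtype, gcd(4, 2) == 2 limits the rewritten pointer to
# "float16x2" rather than "float16x4".
@T.prim_func
def reads_only(a: T.handle("float16", "global"), b: T.handle("float16", "global")):
    A = T.decl_buffer((8,), "float16", data=a)
    B = T.decl_buffer((8,), "float16", data=b)
    for i in range(2):
        B[i * 4:i * 4 + 4] = A[i * 4:i * 4 + 4]  # 4-lane vector reads
    for i in range(4):
        B[i] = A[i * 2]  # scalar reads with coefficient 2
```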
Force-pushed from e523259 to 9a90ec1.
Hi @vinx13, we noticed that this PR breaks the WebGPU codegen: the WebGPU codegen right now does not support tir::ShuffleNode, so exceptions are thrown during codegen. Here is a reproducible example: https://gist.github.com/MasterJH5574/3e6141a1d6dcafd383ed6f7ac27e3318. It needs to be run under commit mlc-ai/relax@631f37b, since the WebGPU codegen in mlc-ai/relax is slightly different from the codegen here due to some conflicts. cc @tqchen

PrimFunc after PointerValueTypeRewrite:

```python
# from tvm.script import tir as T
@T.prim_func
def fused_fused_decode4_NT_matmul9_kernel(lv1654: T.handle("float16x4", "global"), lv635: T.handle("uint32", "global"), lv636: T.handle("float16", "global"), var_NT_matmul_intermediate: T.handle("float16", "global")):
T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "wasm32-unknown-unknown-wasm", "tag": ""}, "keys": ["webgpu", "gpu"], "kind": "webgpu", "max_num_threads": 256, "tag": ""}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x", "threadIdx.y"], "tir.noalias": T.bool(True)})
var_NT_matmul_intermediate_1 = T.decl_buffer((22016,), "float16", data=var_NT_matmul_intermediate)
red_buf0 = T.handle("float16", "shared")
red_buf0_1 = T.decl_buffer((256,), "float16", data=red_buf0, scope="shared")
var_NT_matmul_intermediate_rf_local_1 = T.handle("float16", "local")
var_NT_matmul_intermediate_rf_local_1_1 = T.decl_buffer((1,), "float16", data=var_NT_matmul_intermediate_rf_local_1, scope="local")
lv636_1 = T.decl_buffer((2818048,), "float16", data=lv636)
lv635_1 = T.decl_buffer((11272192,), "uint32", data=lv635)
lv635_local = T.handle("uint32", "local")
lv635_local_1 = T.decl_buffer((1,), "uint32", data=lv635_local, scope="local")
var_NT_matmul_intermediate_rf_local = T.handle("float16x2", "local")
var_NT_matmul_intermediate_rf_local_2 = T.decl_buffer((2,), "float16", data=var_NT_matmul_intermediate_rf_local, scope="local")
lv1654_1 = T.decl_buffer((4096,), "float16", data=lv1654)
lv1654_shared = T.handle("float16", "shared")
lv1654_shared_1 = T.decl_buffer((4096,), "float16", data=lv1654_shared, scope="shared")
blockIdx_x = T.launch_thread("blockIdx.x", 688)
lv1654_shared = T.allocate([4096], "float16", "shared")
var_NT_matmul_intermediate_rf_local = T.allocate([1], "float16x2", "local")
lv635_local = T.allocate([1], "uint32", "local")
var_NT_matmul_intermediate_rf_local_1 = T.allocate([1], "float16", "local")
red_buf0 = T.allocate([256], "float16", "shared")
T.attr(red_buf0, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 8)
threadIdx_y = T.launch_thread("threadIdx.y", 32)
ax2_0 = T.int32()
with T.attr(ax2_0, "pragma_vectorize", 1):
for ax2_0 in range(4):
lv1654_2 = T.Buffer((1024,), "float16x4", data=lv1654)
lv1654_shared_1[ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4:ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4 + 4] = lv1654_2[T.Div(ax2_0 * 1024 + threadIdx_y * 32 + threadIdx_x * 4, 4)]
var_NT_matmul_intermediate_rf_local_3 = T.Buffer((1,), "float16x2", data=var_NT_matmul_intermediate_rf_local, scope="local")
var_NT_matmul_intermediate_rf_local_3[0] = T.Broadcast(T.float16(0), 2)
T.tvm_storage_sync("shared")
for ax1_0_fused_ax1_1_fused_0 in range(64):
lv635_local_1[0] = lv635_1[blockIdx_x * 16384 + threadIdx_y * 512 + ax1_0_fused_ax1_1_fused_0 * 8 + threadIdx_x]
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(0), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 2 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(8), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 4:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 4 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(16), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_3[0] = T.call_pure_extern("float16x2", "fma", lv1654_shared_1[ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 6:ax1_0_fused_ax1_1_fused_0 * 64 + threadIdx_x * 8 + 6 + 2], (T.Cast("float16x2", T.bitwise_and(T.shift_right(T.Broadcast(lv635_local_1[0], 2), T.Ramp(T.uint32(24), T.uint32(4), 2)), T.Broadcast(T.uint32(15), 2))) - T.Broadcast(T.float16(7), 2)) * T.Broadcast(lv636_1[blockIdx_x * 4096 + threadIdx_y * 128 + ax1_0_fused_ax1_1_fused_0 * 2 + T.shift_right(threadIdx_x, 2)], 2), var_NT_matmul_intermediate_rf_local_3[0])
var_NT_matmul_intermediate_rf_local_1_1[0] = T.float16(0)
var_NT_matmul_intermediate_rf_local_1_1[0] = var_NT_matmul_intermediate_rf_local_1_1[0] + T.Shuffle([var_NT_matmul_intermediate_rf_local_3[0]], [0])
var_NT_matmul_intermediate_rf_local_1_1[0] = var_NT_matmul_intermediate_rf_local_1_1[0] + T.Shuffle([var_NT_matmul_intermediate_rf_local_3[0]], [1])
with T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
T.tvm_storage_sync("shared")
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = var_NT_matmul_intermediate_rf_local_1_1[0]
T.tvm_storage_sync("shared")
if threadIdx_x < 4:
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 4]
T.tvm_storage_sync("shared")
if threadIdx_x < 2:
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 2]
T.tvm_storage_sync("shared")
if threadIdx_x < 1:
red_buf0_1[threadIdx_y * 8 + threadIdx_x] = red_buf0_1[threadIdx_y * 8 + threadIdx_x] + red_buf0_1[threadIdx_y * 8 + threadIdx_x + 1]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
            var_NT_matmul_intermediate_1[blockIdx_x * 32 + threadIdx_y] = red_buf0_1[threadIdx_y * 8]
```

Codegen error message:
Is it possible to support it in codegen? Usually this can be supported via element extraction, e.g.
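Sketching what that could look like (a hedged pseudo-visitor in Python; `visit_shuffle` and `visit_expr` are hypothetical stand-ins, not TVM's actual C++ codegen API): a shuffle with a single source vector and a single constant index is just element extraction, which the target language can express as indexing into the vector value.

```python
# Hedged sketch of lowering a shuffle by element extraction; the emitted
# syntax is a stand-in for whatever the real codegen's printer produces.
def visit_shuffle(vectors, indices, visit_expr):
    if len(vectors) == 1 and len(indices) == 1:
        # T.Shuffle([v], [i]) selects lane i of v: emit plain indexing,
        # e.g. "v[2]", instead of a dedicated shuffle construct.
        return f"{visit_expr(vectors[0])}[{indices[0]}]"
    raise NotImplementedError("general multi-vector shuffles")

# Usage with a trivial printer:
print(visit_shuffle(["v"], [2], lambda e: e))  # v[2]
```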
I think we should support it via codegen.
Revert "[TIR] Shuffle in PointerValueTypeRewrite for scalar reads (apache#15517)". This reverts commit 925148e.
Will look into adding Shuffle support for WebGPU! |
Hi @vinx13, here's a minimal reproducible script; it should reproduce the same error when run on the up-to-date unity branch: https://gist.github.com/guoyaol/af7d1161124987b69b8eb2744b1c399e
@guoyaol it may be an unaligned vectorization, or a false alarm where the arithmetic analyzer can't handle the index pattern.
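As a hedged illustration of the unaligned case (hypothetical numbers, not taken from the script above): rewriting to an n-lane pointer requires every access offset to be expressible in whole vector elements.

```python
# For an n-lane pointer rewrite, a slice A[base : base + n] is a whole vector
# element only when base is a multiple of n; otherwise it straddles two
# elements and the rewrite (or the analyzer proving divisibility) fails.
def is_aligned(base: int, lanes: int) -> bool:
    return base % lanes == 0

print(is_aligned(8, 4))  # True:  A[8:12] == A4[2]
print(is_aligned(6, 4))  # False: A[6:10] spans A4[1] and A4[2]
```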