diff --git a/aiter/jit/utils/chip_info.py b/aiter/jit/utils/chip_info.py index 99fd90f400..918a989781 100644 --- a/aiter/jit/utils/chip_info.py +++ b/aiter/jit/utils/chip_info.py @@ -27,6 +27,7 @@ 15: "gfx1153", 16: "gfx1200", 17: "gfx1201", + 18: "gfx1250", } diff --git a/aiter/ops/triton/_gluon_kernels/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_gluon_kernels/moe/moe_op_gemm_a8w4.py new file mode 100644 index 0000000000..150885015e --- /dev/null +++ b/aiter/ops/triton/_gluon_kernels/moe/moe_op_gemm_a8w4.py @@ -0,0 +1,568 @@ +import torch +import triton +import triton.language as tl +from triton.experimental import gluon +import triton.experimental.gluon.language as gl +from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd, pid_grid +from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant + + +def matmul_launch_metadata(grid, kernel, args): + ret = dict() + M, N, K = None, args["N"], args["K"] + Y, X, W = args["Y"], args["X"], args["W"] + hist = args["ExptHist"] + if hist is not None: + n_rows = int(hist.float().mean()) + n_tokens = float(hist.sum()) + n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum() + else: + n_tokens = None + n_w_bytes = W.numel() * W.element_size() + + def repr(s, x): + return f"{s}={x}" if x is not None else f"E_{len(hist)}({s})={n_rows}" + + nbits = X.dtype.itemsize * 8 + ret["name"] = f"{kernel.name} [{repr('M', M)}, {repr('N', N)}, {repr('K', K)}]" + gindx = args.get("GatherIndx", None) + # sindx = args.get("WriteBackIndx", None) + if gindx is not None: + ret["name"] += "_layer1" + else: + ret["name"] += "_layer2" + if args["B"] is not None: + ret["name"] += "_bias" + if args["APPLY_SWIGLU"]: + ret["name"] += "_swiglu" + if args["Quant_static_scale"] is not None: + ret["name"] += "_quant" + + fM = n_tokens + fK = K if K is not None else n_tokens + ret[f"flops{nbits}"] = 2.0 * fM * N * fK + + gindx = args.get("GatherIndx", None) + # sindx = args.get("WriteBackIndx", None) + n_x_bytes = X.numel() * X.element_size() + n_y_bytes = Y.numel() * Y.element_size() + if hist is not None: + assert n_tokens is not None + n_expts_act = args["N_EXPTS_ACT"] + + if gindx is not None: + # recreate inverse GatherIndx. + dst = torch.full_like(gindx, -1) + idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32) + mask = gindx != -1 + dst[gindx[mask]] = idx[mask] + n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum() + else: + n_read_rows = n_tokens + n_x_bytes = n_read_rows * X.shape[-1] * X.element_size() + n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size() + ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes) + + return ret + + +@gluon.jit +def clip(x, limit, clip_lower: tl.constexpr): + res = gl.minimum(x, limit) + if clip_lower: + res = gl.maximum(-limit, res) + return res + + +@gluon.jit +def _swiglu(input, alpha, limit): + gelu, linear = gl.split(gl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) + gelu = gelu.to(gl.float32) + if limit is not None: + gelu = clip(gelu, limit, clip_lower=False) + linear = linear.to(tl.float32) + if limit is not None: + linear = clip(linear, limit, clip_lower=True) + s = gelu / (1 + gl.exp2(-1.44269504089 * alpha * gelu)) + return gl.fma(s, linear, s) # (s * (linear + 1)) + + +@triton.jit +def _reduce_grouped( + X, + stride_xb: tl.uint64, + stride_xm: tl.uint64, + stride_xn, # + Out, + stride_om: tl.uint64, + stride_on, # output tensor + InIndx, + B, + N, # + # fused activation function + APPLY_SWIGLU: tl.constexpr, + alpha, + limit, + ACTIVATION_REDUCTION_N: tl.constexpr, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + EVEN_N: tl.constexpr, +): + pid_t = tl.program_id(1) + pid_n = tl.program_id(0) + + BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N + start = pid_t * K + # load indices into a tuple + if InIndx is None: + indxs = (pid_t,) + else: + indxs = () + for i in tl.static_range(0, K): + indxs = indxs + (tl.load(InIndx + start + i),) + XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn + OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on + + acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) + x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N + # accumulate contributions for this tile + for i in tl.static_range(0, K): + curr = tl.zeros([BLOCK_N], dtype=tl.float32) + # iterate over split_k partial values + for b in tl.range(0, B): + x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb + if EVEN_N: + vals = tl.load(x_row_ptr) + else: + vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) + vals = vals.to(tl.float32) + curr += vals + + # apply nonlinearity to split-k output + if APPLY_SWIGLU: + curr = _swiglu(curr[None, :], alpha, limit) + curr = tl.reshape(curr, [curr.shape[-1]]) + # update final accumulator + acc += curr + # Compute per-32-col MXFP scales for this tile if requested + Nrem = N // ACTIVATION_REDUCTION_N + + # write-back for this tile + out_ptr = OutPtrs + pid_t * stride_om + if EVEN_N: + tl.store(out_ptr, acc) + else: + out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem + tl.store(out_ptr, acc, mask=out_n_mask) + + +@gluon.jit(launch_metadata=matmul_launch_metadata) +def _moe_gemm_a8w4( + Y, + stride_y_k, + stride_y_m, + stride_y_n, + X, + stride_x_m, + stride_x_k, + XMxScale, + stride_x_mx_m, + stride_x_mx_k, + W, + stride_w_e, + stride_w_n, + stride_w_k, + WMxScale, + stride_w_mx_e, + stride_w_mx_n, + stride_w_mx_k, + X_static_scale, + Quant_static_scale, + B, + stride_b_e, # Bias + Gammas, + num_tokens, + N, + K, # shapes + # expt data + GatherIndx, + ExptHist, + ExptOffs, + ExptOffsSum, + ExptData, + # true grid size + grid_m, + grid_n, + # fused activation function + APPLY_SWIGLU: gl.constexpr, + alpha, + limit, + ACTIVATION_REDUCTION_N: gl.constexpr, + # MoE config + N_EXPTS_ACT: gl.constexpr, + # optimization config + BLOCK_M: gl.constexpr, + BLOCK_N: gl.constexpr, + BLOCK_K: gl.constexpr, + XCD_SWIZZLE: gl.constexpr, + NUM_BUFFERS: gl.constexpr, + # One of ["GFX1250", None] + SWIZZLE_MX_SCALE: gl.constexpr, + EVEN_K: gl.constexpr, + MASK_K_LIMIT: gl.constexpr, + W_CACHE_MODIFIER: gl.constexpr, + UPCAST_INDICES: gl.constexpr = False, +): + + is_x_microscaled: gl.constexpr = XMxScale is not None + MX_PACK_DIVISOR: gl.constexpr = 32 + w_type: gl.constexpr = W.dtype.element_ty + gl.static_assert(w_type == gl.uint8, "mx_weight_ptr must be uint8 or fp8") + gl.static_assert( + WMxScale.dtype.element_ty == gl.uint8, "mx_scale_ptr must be uint8" + ) + gl.static_assert( + BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR" + ) + x_type: gl.constexpr = X.dtype.element_ty + if is_x_microscaled: + gl.static_assert(x_type == gl.float8e4nv, "mx_act_ptr must be float8e4nv") + gl.static_assert( + XMxScale.dtype.element_ty == gl.uint8, "mx_scale_ptr must be uint8" + ) + + OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N + yN = N // ACTIVATION_REDUCTION_N + + pid = gl.program_id(0) + if ExptOffsSum is not None: + # Determine how much padding there is on the expert data. This allows us to + # know the true grid size and avoid processing padding tiles. + padding_m = grid_m - gl.load(ExptOffsSum) + else: + padding_m: tl.constexpr = 0 + + index_type: tl.constexpr = gl.int64 if UPCAST_INDICES else gl.int32 + + unpadded_m = grid_m - padding_m + total_actual_tiles = unpadded_m * grid_n + if padding_m > 0 and pid >= total_actual_tiles: + return + + pid_mn = pid % (unpadded_m * grid_n) + if XCD_SWIZZLE != 1: + pid_mn = remap_xcd(pid_mn, total_actual_tiles, XCD_SWIZZLE) + pid_m, pid_n = pid_grid(pid_mn, unpadded_m, grid_n, 1) + # unpack expert data + expt_data = gl.load(ExptData + pid_m) + if expt_data == -1: + return + expt_id = expt_data & 0x0000FFFF + block_id = expt_data >> 16 + M = gl.load(ExptHist + expt_id) + start_m = gl.load(ExptOffs + expt_id) + expt_id, block_id = expt_id.to(index_type), block_id.to(index_type) + start_m = start_m.to(index_type) + pid_n = pid_n.to(index_type) + + # A pointers + off_x_m = BLOCK_M * block_id + if GatherIndx is None: + X += start_m * stride_x_m + else: + IDX_LAYOUT: gl.constexpr = gl.SliceLayout( + 1, gl.BlockedLayout([BLOCK_M, 1], [1, 32], [1, 4], [1, 0]) + ) + offs_x_m = BLOCK_M * block_id + gl.arange(0, BLOCK_M, layout=IDX_LAYOUT) + GatherIndx += start_m + offs_x_m = gl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT + + W_K_DIVISOR: gl.constexpr = 2 + PACKED_BLOCK_K_W: gl.constexpr = BLOCK_K // W_K_DIVISOR + PACKED_BLOCK_N_W: gl.constexpr = BLOCK_N + MX_SCALE_BLOCK_K: gl.constexpr = BLOCK_K // MX_PACK_DIVISOR + + WMxScale += expt_id * stride_w_mx_e + if SWIZZLE_MX_SCALE == "GFX1250_SCALE": + gl.static_assert(stride_w_mx_k is not None) + gl.static_assert(stride_w_mx_n is not None) + PRESHUFFLE_FACTOR: gl.constexpr = 128 + PACKED_MX_BLOCK: gl.constexpr = MX_SCALE_BLOCK_K * PRESHUFFLE_FACTOR + SCALE_BLOCK_N: gl.constexpr = BLOCK_N // PRESHUFFLE_FACTOR + SCALE_KWIDTH: gl.constexpr = 4 if MX_SCALE_BLOCK_K >= 4 else MX_SCALE_BLOCK_K + else: + PRESHUFFLE_FACTOR: gl.constexpr = 1 + PACKED_MX_BLOCK: gl.constexpr = MX_SCALE_BLOCK_K + SCALE_BLOCK_N: gl.constexpr = BLOCK_N + off_w_n_scale = pid_n * SCALE_BLOCK_N + + # B pointers + off_w_n = pid_n * PACKED_BLOCK_N_W + W += expt_id * stride_w_e + + SHARED_LAYOUT_X: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[BLOCK_K, 16]], [BLOCK_M, BLOCK_K], [1, 0] + ) + SHARED_LAYOUT_W: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[PACKED_BLOCK_K_W, 16]], [BLOCK_N, PACKED_BLOCK_K_W], [1, 0] + ) + SHARED_LAYOUT_W_SCALES: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + [[256, 16]], [SCALE_BLOCK_N, PACKED_MX_BLOCK], [1, 0] + ) + + if GatherIndx is None: + x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=X, + shape=(M, K), + strides=(stride_x_m, stride_x_k), + block_shape=(BLOCK_M, BLOCK_K), + layout=SHARED_LAYOUT_X, + ) + else: + x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=X, + shape=(num_tokens, K), + strides=(stride_x_m, stride_x_k), + block_shape=(BLOCK_M, BLOCK_K), + layout=SHARED_LAYOUT_X, + ) + w_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=W, + shape=(N, K // W_K_DIVISOR), + strides=( + stride_w_n, + stride_w_k, + ), + block_shape=( + BLOCK_N, + PACKED_BLOCK_K_W, + ), + layout=SHARED_LAYOUT_W, + ) + w_scales_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=WMxScale, + shape=(N // PRESHUFFLE_FACTOR, tl.cdiv(K, MX_PACK_DIVISOR) * PRESHUFFLE_FACTOR), + strides=(stride_w_mx_n, stride_w_mx_k), + block_shape=(SCALE_BLOCK_N, PACKED_MX_BLOCK), + layout=SHARED_LAYOUT_W_SCALES, + ) + + if SWIZZLE_MX_SCALE == "GFX1250_SCALE": + WMMA_LAYOUT: gl.constexpr = gl.amd.AMDWMMALayout( + 3, + transposed=True, + warp_bases=[[0, 2], [1, 0]], + reg_bases=[[0, 1]], + instr_shape=[16, 16, 128], + ) + WMMA_LAYOUT_PACKED: gl.constexpr = gl.amd.AMDWMMALayout( + 3, + transposed=True, + warp_bases=[[0, 2], [1, 0]], + reg_bases=[[0, 1]], + instr_shape=[16, 16, 64], + ) + else: + WMMA_LAYOUT: gl.constexpr = gl.amd.AMDWMMALayout( + 3, + transposed=True, + warp_bases=[[0, 1], [1, 0]], + reg_bases=[], + instr_shape=[16, 16, 128], + ) + WMMA_LAYOUT_PACKED: gl.constexpr = gl.amd.AMDWMMALayout( + 3, + transposed=True, + warp_bases=[[0, 1], [1, 0]], + reg_bases=[], + instr_shape=[16, 16, 64], + ) + DOT_LAYOUT_X: gl.constexpr = gl.DotOperandLayout(0, WMMA_LAYOUT, k_width=16) + DOT_LAYOUT_W: gl.constexpr = gl.DotOperandLayout(1, WMMA_LAYOUT_PACKED, k_width=16) + DOT_LAYOUT_W_SCALES: gl.constexpr = gl.amd.gfx1250.get_wmma_scale_layout( + DOT_LAYOUT_W, [BLOCK_N, MX_SCALE_BLOCK_K] + ) + + x_buffer = gl.allocate_shared_memory( + x_desc.dtype, shape=[NUM_BUFFERS] + x_desc.block_shape, layout=x_desc.layout + ) + w_buffer = gl.allocate_shared_memory( + w_desc.dtype, shape=[NUM_BUFFERS] + w_desc.block_shape, layout=w_desc.layout + ) + w_scales_buffer = gl.allocate_shared_memory( + w_scales_desc.dtype, + shape=[NUM_BUFFERS] + w_scales_desc.block_shape, + layout=w_scales_desc.layout, + ) + + read_idx = 0 + write_idx = 0 + for _ in gl.static_range(NUM_BUFFERS - 1): + if GatherIndx is None: + gl.amd.gfx1250.tdm.async_load( + x_desc, + [off_x_m, write_idx * BLOCK_K], + x_buffer.index(write_idx % NUM_BUFFERS), + ) + else: + gl.amd.gfx1250.tdm.async_gather( + x_desc, + offs_x_m, + write_idx * BLOCK_K, + x_buffer.index(write_idx % NUM_BUFFERS), + ) + gl.amd.gfx1250.tdm.async_load( + w_desc, + [off_w_n, write_idx * PACKED_BLOCK_K_W], + w_buffer.index(write_idx % NUM_BUFFERS), + ) + gl.amd.gfx1250.tdm.async_load( + w_scales_desc, + [off_w_n_scale, write_idx * PACKED_MX_BLOCK], + w_scales_buffer.index(write_idx % NUM_BUFFERS), + ) + write_idx += 1 + + # compute output + num_k_iter = tl.cdiv(K, BLOCK_K) + acc = gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=WMMA_LAYOUT) + for k in range(num_k_iter - (NUM_BUFFERS - 1)): + if GatherIndx is None: + gl.amd.gfx1250.tdm.async_load( + x_desc, + [off_x_m, write_idx * BLOCK_K], + x_buffer.index(write_idx % NUM_BUFFERS), + ) + else: + gl.amd.gfx1250.tdm.async_gather( + x_desc, + offs_x_m, + write_idx * BLOCK_K, + x_buffer.index(write_idx % NUM_BUFFERS), + ) + gl.amd.gfx1250.tdm.async_load( + w_desc, + [off_w_n, write_idx * PACKED_BLOCK_K_W], + w_buffer.index(write_idx % NUM_BUFFERS), + ) + gl.amd.gfx1250.tdm.async_load( + w_scales_desc, + [off_w_n_scale, write_idx * PACKED_MX_BLOCK], + w_scales_buffer.index(write_idx % NUM_BUFFERS), + ) + write_idx += 1 + + gl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 1) * 3) + + x = x_buffer.index(read_idx % NUM_BUFFERS).load(layout=DOT_LAYOUT_X) + w = ( + w_buffer.index(read_idx % NUM_BUFFERS) + .permute((1, 0)) + .load(layout=DOT_LAYOUT_W) + ) + w_scales_buffer_slice = w_scales_buffer.index(read_idx % NUM_BUFFERS) + if SWIZZLE_MX_SCALE == "GFX1250_SCALE": + w_scales_buffer_slice = ( + w_scales_buffer_slice.reshape( + ( + SCALE_BLOCK_N, + MX_SCALE_BLOCK_K // SCALE_KWIDTH, + PRESHUFFLE_FACTOR // 4, + 4, + SCALE_KWIDTH, + ) + ) + .permute((0, 3, 2, 1, 4)) + .reshape((BLOCK_N, MX_SCALE_BLOCK_K)) + ) + w_scales = w_scales_buffer_slice.load(layout=DOT_LAYOUT_W_SCALES) + read_idx += 1 + + acc = gl.amd.gfx1250.wmma_scaled(x, 0, "e4m3", w, w_scales, "e2m1", acc) + + for k_ep in gl.static_range(NUM_BUFFERS - 1): + gl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 2 - k_ep) * 3) + + x = x_buffer.index(read_idx % NUM_BUFFERS).load(layout=DOT_LAYOUT_X) + w = ( + w_buffer.index(read_idx % NUM_BUFFERS) + .permute((1, 0)) + .load(layout=DOT_LAYOUT_W) + ) + w_scales_buffer_slice = w_scales_buffer.index(read_idx % NUM_BUFFERS) + if SWIZZLE_MX_SCALE == "GFX1250_SCALE": + w_scales_buffer_slice = ( + w_scales_buffer_slice.reshape( + ( + SCALE_BLOCK_N, + MX_SCALE_BLOCK_K // SCALE_KWIDTH, + PRESHUFFLE_FACTOR // 4, + 4, + SCALE_KWIDTH, + ) + ) + .permute((0, 3, 2, 1, 4)) + .reshape((BLOCK_N, MX_SCALE_BLOCK_K)) + ) + w_scales = w_scales_buffer_slice.load(layout=DOT_LAYOUT_W_SCALES) + + acc = gl.amd.gfx1250.wmma_scaled(x, 0, "e4m3", w, w_scales, "e2m1", acc) + + # scalar fp8 scale + if X_static_scale is not None: + acc = acc * gl.load(X_static_scale) + # bias + offs_m = BLOCK_M * block_id + gl.arange( + 0, BLOCK_M, layout=gl.SliceLayout(1, WMMA_LAYOUT) + ) + offs_y_n = BLOCK_N * pid_n + gl.arange( + 0, BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT) + ) + mask_m = offs_m < M + mask_n = offs_y_n < N + if B is not None: + BPtrs = B + expt_id * stride_b_e + bias = gl.amd.gfx1250.buffer_load(BPtrs, offs_y_n, mask=mask_n) + acc = acc + bias[None, :] + if APPLY_SWIGLU: + out = _swiglu(acc, alpha, limit) + tl.static_assert( + out.shape[1] == OUT_BLOCK_N, + f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})", + ) + # STORE_LAYOUT: gl.constexpr = gl.BlockedLayout(size_per_thread=[1, 16], threads_per_warp=[4, 8], warps_per_cta=[4, 1], order=[1, 0]) + offs_m = BLOCK_M * block_id + gl.arange(0, BLOCK_M) + offs_y_n = OUT_BLOCK_N * pid_n + gl.arange(0, OUT_BLOCK_N) + mask_m = offs_m < M + mask_n = offs_y_n < yN + else: + tl.static_assert( + ACTIVATION_REDUCTION_N == 1, + "Activation reduction must be 1 if no activation fn is provided", + ) + out = acc + if Gammas is not None: + gammas = gl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0) + out *= gammas[:, None] + # quant + if Quant_static_scale is not None: + out = _compute_static_fp8_quant(out, gl.load(Quant_static_scale)) + # write-back + Y += start_m * stride_y_m + offs_y_m = offs_m + # YPtrs = ( + # Y + # + offs_y_m.to(index_type)[:, None] * stride_y_m + # + offs_y_n.to(index_type)[None, :] * stride_y_n + # ) + offs_y = ( + offs_y_m.to(index_type)[:, None] * stride_y_m + + offs_y_n.to(index_type)[None, :] * stride_y_n + ) + mask = mask_m[:, None] & mask_n[None, :] + if Quant_static_scale is None: + out = out.to(tl.bfloat16) + # if APPLY_SWIGLU: + # out = gl.convert_layout(out, layout=STORE_LAYOUT) + # gl.store(YPtrs, out, mask=mask) + gl.amd.gfx1250.buffer_store(out, Y, offs_y, mask=mask) diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py index 46102d47d7..b55951f45b 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py @@ -6,13 +6,14 @@ import triton from aiter.ops.triton.moe.moe_routing.routing import RoutingData from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_a8w4 import ( - _moe_gemm_a8w4, + _moe_gemm_a8w4 as _moe_gemm_a8w4_triton, _reduce_grouped, ) - -# ----------------------------------------------------------------------------- -# Matrix Multiplication + Outer Gather/Scatter -# ----------------------------------------------------------------------------- +from aiter.ops.triton._gluon_kernels.moe.moe_op_gemm_a8w4 import ( + _moe_gemm_a8w4 as _moe_gemm_a8w4_gluon, +) +from aiter.ops.triton.utils._triton.arch_info import get_arch +from aiter.ops.triton.utils.device_info import get_num_sms def can_overflow_int32(tensor: torch.Tensor): @@ -28,8 +29,8 @@ def should_upcast_indices(*args): def allocate_output( - x, - w, + M, + N, out_dtype, reduction_n_matmul, reduction_n_reduction, @@ -38,15 +39,8 @@ def allocate_output( scatter_indx, block_m, split_k, + device, ): - # ---- output ------ - N = w.shape[-1] - # by default - M is number of rows in the activations - M = x.shape[-2] - # if the activations are gathered, then M is number of gather indices - if gather_indx is not None: - M = gather_indx.shape[0] - # final output if routing_data.n_expts_act == 1 or scatter_indx is None: y_rows = M else: @@ -55,15 +49,15 @@ def allocate_output( ) # compressed number of rows matmul_shape = (split_k, M, N // reduction_n_matmul) final_shape = (y_rows, N // reduction_n_matmul // reduction_n_reduction) - matmul_output = torch.empty(matmul_shape, device=x.device, dtype=out_dtype) + matmul_output = torch.empty(matmul_shape, device=device, dtype=out_dtype) if scatter_indx is not None or split_k > 1: - final_output = torch.empty(final_shape, device=x.device, dtype=out_dtype) + final_output = torch.empty(final_shape, device=device, dtype=out_dtype) else: final_output = None return matmul_output, final_output -def get_kernel_config(m, n, k, routing_data): +def get_kernel_config_triton(m, n, k, routing_data): block_m = routing_data.block_m group_m = 4 num_xcds = 8 @@ -80,7 +74,7 @@ def get_kernel_config(m, n, k, routing_data): grid_m = routing_data.n_blocks(m, block_m) grid_n = triton.cdiv(n, block_n) grid = grid_m * grid_n * split_k - while block_n >= 64 and grid < 256: + while block_n >= 64 and grid < get_num_sms(): block_n = block_n // 2 grid_m = routing_data.n_blocks(m, block_m) grid_n = triton.cdiv(n, block_n) @@ -118,7 +112,57 @@ def get_kernel_config(m, n, k, routing_data): return ret -def swizzle_scales(data): +def get_kernel_config_gluon(m, n, k, routing_data): + block_m = routing_data.block_m + num_xcds = 1 + w_cache_modifier = ".cg" if block_m <= 32 else None + num_stages = 2 + split_k = 1 + block_k = 256 + + if block_m == 16: + block_n = 128 + num_warps = 4 + + grid_m = routing_data.n_blocks(m, block_m) + grid_n = triton.cdiv(n, block_n) + grid = grid_m * grid_n * split_k + while block_n >= 64 and grid < get_num_sms(): + block_n = block_n // 2 + grid_m = routing_data.n_blocks(m, block_m) + grid_n = triton.cdiv(n, block_n) + grid = grid_m * grid_n * split_k + + elif block_m == 32: + if n <= 1024: + block_n = 128 + num_warps = 4 + elif n <= 4096: + block_n = 256 + num_warps = 4 + else: + block_n = 512 + num_warps = 4 + + else: + block_n = 512 + num_warps = 4 + + ret = { + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_warps": num_warps, + "num_stages": num_stages, + "xcd_swizzle": num_xcds, + "split_k": split_k, + "w_cache_modifier": w_cache_modifier, + "waves_per_eu": 0, + } + return ret + + +def swizzle_scales_gfx950(data): NON_K_PRESHUFFLE_BLOCK_SIZE = 32 block_shape = data.shape SCALE_K = block_shape[-2] @@ -131,6 +175,24 @@ def swizzle_scales(data): return data.transpose(-1, -2) +def swizzle_scales_gfx1250(data): + E, K_SCALE, N = data.shape + preshuffle_factor = 128 + num_chunk_n = N // preshuffle_factor + SCALE_KWIDTH = 4 if K_SCALE >= 4 else K_SCALE + num_chunk_k = K_SCALE // SCALE_KWIDTH + + data = data.transpose(-1, -2) + data = data.view( + E, num_chunk_n, 4, preshuffle_factor // 4, num_chunk_k, SCALE_KWIDTH + ) + data = data.permute(0, 1, 4, 3, 2, 5).contiguous() + data = data.view(E, N // preshuffle_factor, K_SCALE * preshuffle_factor) + data = data.transpose(-1, -2) + + return data + + def reduce_grouped( x: torch.Tensor, indx: torch.Tensor, @@ -235,6 +297,7 @@ def moe_gemm_a8w4( for e in num_experts: Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :]) """ + use_gluon = get_arch() == "gfx1250" assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp" x_has_mx = x_scales is not None if x_has_mx: @@ -246,15 +309,22 @@ def moe_gemm_a8w4( stride_x_mx_m = 0 stride_x_mx_k = 0 # determine shapes - M = x.shape[-2] if gather_indx is None else gather_indx.shape[0] + num_tokens = x.shape[-2] + M = num_tokens if gather_indx is None else gather_indx.shape[0] K, N = x.shape[-1], w.shape[-1] block_m = routing_data.block_m if unpadded_N and block_m == 16: N = unpadded_N if unpadded_K and block_m == 16: K = unpadded_K + if use_gluon: + w = w.transpose(1, 2) + w_scales = w_scales.transpose(1, 2) # compute optimization flags - config = get_kernel_config(M, N, K, routing_data) + if use_gluon: + config = get_kernel_config_gluon(M, N, K, routing_data) + else: + config = get_kernel_config_triton(M, N, K, routing_data) if apply_swiglu and config["split_k"] > 1: apply_swiglu_matmul = False reduction_n_matmul = 1 @@ -272,8 +342,8 @@ def moe_gemm_a8w4( reduction_n_reduction = 1 # allocate output memory y, y_final = allocate_output( - x, - w, + M, + N, out_dtype, reduction_n_matmul, reduction_n_reduction, @@ -282,6 +352,7 @@ def moe_gemm_a8w4( scatter_indx, config["block_m"], config["split_k"], + x.device, ) stride_bias = None if bias is None else bias.stride(0) # moe metadata @@ -290,66 +361,121 @@ def moe_gemm_a8w4( expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[-1] expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map - # spmd grid + # pid grid grid_m = routing_data.n_blocks(M, config["block_m"]) grid_n = triton.cdiv(N, config["block_n"]) grid = grid_m * grid_n * config["split_k"] # launch kernel - _moe_gemm_a8w4[(grid,)]( - y, - y.stride(0), - y.stride(1), - y.stride(2), - x, - x.stride(0), - x.stride(1), - x_scales, - stride_x_mx_m, - stride_x_mx_k, - w, - w.stride(0), - w.stride(1), - w.stride(2), - w_scales, - w_scales.stride(0), - w_scales.stride(1), - w_scales.stride(2), - x_static_scale, - quant_static_scale, - bias, - stride_bias, - gammas, - N, - K, - gather_indx, - expt_hist, - expt_token_offs_raw, - expt_hist_sum, - expt_block_pid_map, - grid_m, - grid_n, - apply_swiglu_matmul, - alpha, - limit, - reduction_n_matmul, - routing_data.n_expts_act, - config["block_m"], - config["block_n"], - config["block_k"], - config["group_m"], - XCD_SWIZZLE=config["xcd_swizzle"], - SWIZZLE_MX_SCALE=swizzle_mx_scale, - SPLIT_K=config["split_k"], - EVEN_K=K % config["block_k"] == 0, - MASK_K_LIMIT=K % config["block_k"], - W_CACHE_MODIFIER=config["w_cache_modifier"], - num_warps=config["num_warps"], - num_stages=config["num_stages"], - UPCAST_INDICES=should_upcast_indices(x, w, y), - waves_per_eu=config["waves_per_eu"], - matrix_instr_nonkdim=config["matrix_instr_nonkdim"], - kpack=config["kpack"], - ) + if use_gluon: + _moe_gemm_a8w4_gluon[(grid,)]( + y, + y.stride(0), + y.stride(1), + y.stride(2), + x, + x.stride(0), + x.stride(1), + x_scales, + stride_x_mx_m, + stride_x_mx_k, + w, + w.stride(0), + w.stride(1), + w.stride(2), + w_scales, + w_scales.stride(0), + w_scales.stride(1), + w_scales.stride(2), + x_static_scale, + quant_static_scale, + bias, + stride_bias, + gammas, + num_tokens, + N, + K, + gather_indx, + expt_hist, + expt_token_offs_raw, + expt_hist_sum, + expt_block_pid_map, + grid_m, + grid_n, + apply_swiglu_matmul, + alpha, + limit, + reduction_n_matmul, + routing_data.n_expts_act, + config["block_m"], + config["block_n"], + config["block_k"], + XCD_SWIZZLE=config["xcd_swizzle"], + NUM_BUFFERS=config["num_stages"], + SWIZZLE_MX_SCALE=swizzle_mx_scale, + EVEN_K=K % config["block_k"] == 0, + MASK_K_LIMIT=K % config["block_k"], + W_CACHE_MODIFIER=config["w_cache_modifier"], + num_warps=config["num_warps"], + UPCAST_INDICES=should_upcast_indices(x, w, y), + waves_per_eu=config["waves_per_eu"], + ) + else: + _moe_gemm_a8w4_triton[(grid,)]( + y, + y.stride(0), + y.stride(1), + y.stride(2), + x, + x.stride(0), + x.stride(1), + x_scales, + stride_x_mx_m, + stride_x_mx_k, + w, + w.stride(0), + w.stride(1), + w.stride(2), + w_scales, + w_scales.stride(0), + w_scales.stride(1), + w_scales.stride(2), + x_static_scale, + quant_static_scale, + bias, + stride_bias, + gammas, + N, + K, + gather_indx, + expt_hist, + expt_token_offs_raw, + expt_hist_sum, + expt_block_pid_map, + grid_m, + grid_n, + apply_swiglu_matmul, + alpha, + limit, + reduction_n_matmul, + routing_data.n_expts_act, + config["block_m"], + config["block_n"], + config["block_k"], + config["group_m"], + XCD_SWIZZLE=config["xcd_swizzle"], + SWIZZLE_MX_SCALE=swizzle_mx_scale, + SPLIT_K=config["split_k"], + EVEN_K=K % config["block_k"] == 0, + MASK_K_LIMIT=K % config["block_k"], + W_CACHE_MODIFIER=config["w_cache_modifier"], + num_warps=config["num_warps"], + num_stages=config["num_stages"], + UPCAST_INDICES=should_upcast_indices(x, w, y), + waves_per_eu=config["waves_per_eu"], + matrix_instr_nonkdim=config["matrix_instr_nonkdim"], + kpack=config["kpack"], + ) + # Build grouped reduction inputs in a uniform way group_indx = ( None diff --git a/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py b/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py index 1e5404c6c1..c04247795f 100644 --- a/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py +++ b/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py @@ -10,7 +10,7 @@ from aiter.ops.triton.gemm.basic.gemm_a16w16 import gemm_a16w16 from aiter.ops.triton.moe.moe_op_gemm_a8w4 import ( moe_gemm_a8w4, - swizzle_scales, + swizzle_scales_gfx950, ) from aiter.ops.triton.utils._triton.arch_info import get_arch import tempfile @@ -88,7 +88,7 @@ def inject_proxy_and_call(val, args, kwargs): def check_and_swizzle_scales(scale, N, K): if N % 32 == 0 and K % (32 * 8) == 0: - scale = swizzle_scales(scale) + scale = swizzle_scales_gfx950(scale) return scale, "CDNA4_SCALE" else: return scale, None diff --git a/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py b/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py index e75c73a815..be6e9f7306 100644 --- a/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py +++ b/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py @@ -12,7 +12,8 @@ from aiter.ops.triton.moe.moe_op_gemm_a8w4 import ( moe_gemm_a8w4, moe_gemm_torch, - swizzle_scales, + swizzle_scales_gfx950, + swizzle_scales_gfx1250, ) # numerics utilities @@ -68,7 +69,6 @@ def init_compute_data( has_y_gammas, device="cuda", ): - torch.manual_seed(0) in_m = m * (n_expts_act if gindx is None else 1) shape_x = (in_m, k) x = alloc_rand(shape_x, device=device, dtype=act_dtype) @@ -193,6 +193,10 @@ class Case: Case(4096, 256, 256, "mxfloat8_e4m3fn", 128, 4), Case(1000, 704, 800, "mxfloat8_e4m3fn", 8, 2), Case(300, 400, 800, "mxfloat8_e4m3fn", 8, 4), + # smaller tests for gfx1250 ffm + Case(16, 512, 512, "float8_e4m3fn", 32, 2), + Case(16, 512, 512, "float8_e4m3fn", 32, 2, hbm_swizzling=True), + Case(300, 400, 800, "float8_e4m3fn", 8, 4), ] ], ) @@ -224,20 +228,22 @@ def test_op( device="cuda", ): - if get_arch() != "gfx950": - pytest.skip("float8 x mx only supported on CDNA4") + if get_arch() != "gfx950" and get_arch() != "gfx1250": + pytest.skip("Kernel not supported on this GPU.") - if "float8_e4m3fnuz" in act_dtype_str and get_arch() != "gfx942": - pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") + if get_arch() == "gfx1250": + if act_dtype_str == "mxfloat8_e4m3fn": + pytest.skip("Mxfloat activations are not supported yet on gfx1250.") + if apply_swiglu and has_y_gammas: + pytest.skip("Swiglu and gammas are not supported together on gfx1250.") + # temporary + if m > 1024 or n > 1024 or k > 1024 or n_expts_tot > 32: + pytest.skip("Test will take too long time on FFM") if hbm_swizzling: - if get_arch() != "gfx950": - pytest.skip( - "Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet." - ) - if n % 32 != 0 or k % (32 * 8) != 0: + if get_arch() == "gfx950" and (n % 32 != 0 or k % (32 * 8) != 0): pytest.skip( - f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU" + f"Shape {m}x{n}x{k} is not supported for scale swizzling on gfx950." ) torch.manual_seed(0) @@ -274,8 +280,13 @@ def test_op( w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=1) w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=1) if hbm_swizzling: - swizzle_mx_scale = "CDNA4_SCALE" - w_scale_tri = swizzle_scales(w_scale_tri) + if get_arch() == "gfx1250": + swizzle_mx_scale = "GFX1250_SCALE" + w_scale_tri = swizzle_scales_gfx1250(w_scale_tri) + else: + assert get_arch() == "gfx950" + swizzle_mx_scale = "CDNA4_SCALE" + w_scale_tri = swizzle_scales_gfx950(w_scale_tri) else: swizzle_mx_scale = None @@ -283,14 +294,12 @@ def test_op( x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1) x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1) x_static_scale = None - out_dtype = torch.bfloat16 maxtol = None rmstol = None else: x_mx_scales_tri = None x_static_scale = x_tri.abs().max().float() / 448.0 x_tri = downcast_to_static_fp8(x_tri, x_static_scale) - out_dtype = torch.float8_e4m3fn maxtol = 4e-1 rmstol = 4e-2 @@ -299,8 +308,10 @@ def test_op( ) if not act_mxfp8 and fused_quant: quant_static_scale = ref_y.abs().max().float() / 448.0 + out_dtype = torch.float8_e4m3fn else: quant_static_scale = None + out_dtype = torch.bfloat16 tri_y = moe_gemm_a8w4( x_tri, w_tri,