Hello, I want to implement moe_gemm with W4A4 INT4 quantization on the Ada architecture (SM89); the original input type is fp16. Activation quantization is per-token and weight quantization is per-channel. I have implemented the grouped GEMM using CUTLASS's cutlass::gemm::kernel::DefaultGemmGrouped, but I don't know how to implement the dequantization in the epilogue stage, i.e. multiplying the accumulator by the column-broadcast activation scale and the row-broadcast weight scale (scale_act * scale_weight * accumulator). Can you tell me how to do this inside the grouped GEMM?
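Concretely, per output element the epilogue should compute the following (m = token/row index, n = output-channel/column index):

// Desired epilogue computation, per output element:
//   D[m][n] = OutputType( float(accumulator[m][n])
//                         * input_scale[m]      // per-token scale, broadcast across columns
//                         * weight_scale[n] );  // per-channel scale, broadcast across rows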
I am currently running a separate dequantization kernel after the GEMM instead of an epilogue implementation, but the performance is very poor.
This is my core implementation code. Could you please tell me how to implement the dequantization in the epilogue?
template <typename ActivationType, typename WeightType, typename OutputType, typename Arch,
typename ThreadblockShape, typename WarpShape, int Stages>
void Int4MoeGemmKernelLauncherFused<ActivationType, WeightType, OutputType, Arch,
ThreadblockShape, WarpShape, Stages>::run_fused(
ActivationType const* input,
int64_t const* total_tokens_including_expert,
WeightType const* expert_weights,
float const* input_scales,
float const* weight_scales,
OutputType* output,
int64_t num_rows,
int64_t gemm_n,
int64_t gemm_k,
int num_experts,
int sm_count,
cudaStream_t stream,
int* occupancy)
{
// ========================================================================
// Optimized Implementation: GEMM + Vectorized Dequantization
// Uses CUTLASS grouped GEMM for INT4xINT4->INT32 followed by
// an optimized vectorized dequantization kernel.
// ========================================================================
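// Note: the aliases used below (LayoutA, LayoutB, LayoutC, ElementAccumulator,
// InstructionShape, Alignment) are defined elsewhere and not shown in this snippet.
// For INT4 TensorOp on SM80/SM89 they would typically be something like:
//   LayoutA            = cutlass::layout::RowMajor
//   LayoutB            = cutlass::layout::ColumnMajor
//   LayoutC            = cutlass::layout::RowMajor
//   ElementAccumulator = int32_t
//   InstructionShape   = cutlass::gemm::GemmShape<16, 8, 64>
//   Alignment          = 32   // 128-bit accesses / 4-bit elements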
using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
ActivationType, LayoutA,
cutlass::ComplexTransform::kNone, Alignment,
WeightType, LayoutB,
cutlass::ComplexTransform::kNone, Alignment,
ElementAccumulator, LayoutC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
Arch,
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<ElementAccumulator, 128 / cutlass::sizeof_bits<ElementAccumulator>::value, ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
Stages,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly
>::GemmKernel;
using GemmGrouped = cutlass::gemm::device::GemmGrouped<BaseGemmKernel>;
if (occupancy != nullptr) {
*occupancy = GemmGrouped::maximum_active_blocks();
return;
}
// Validate alignment
if (gemm_k % 32 != 0) {
throw std::runtime_error("K dimension must be a multiple of 32 for INT4 TensorOp");
}
if (gemm_n % 32 != 0) {
throw std::runtime_error("N dimension must be a multiple of 32 for INT4 TensorOp");
}
// Copy total_tokens_including_expert from device to host
std::vector<int64_t> tokens_per_expert_host(num_experts + 1);
cudaMemcpy(tokens_per_expert_host.data(), total_tokens_including_expert,
sizeof(int64_t) * (num_experts + 1), cudaMemcpyDeviceToHost);
// Prepare problem configurations
std::vector<cutlass::gemm::GemmCoord> problem_sizes_host;
std::vector<int64_t> lda_host, ldb_host, ldc_host;
std::vector<ActivationType*> ptr_A_host;
std::vector<WeightType*> ptr_B_host;
std::vector<ElementAccumulator*> ptr_C_host;
std::vector<ElementAccumulator*> ptr_D_host;
int64_t total_output_elements = 0;
int64_t k_packed = gemm_k / 2; // INT4 packing: K elements = K/2 bytes
// Cast to uint8_t* for correct byte-based pointer arithmetic
const uint8_t* input_bytes = reinterpret_cast<const uint8_t*>(input);
const uint8_t* weight_bytes = reinterpret_cast<const uint8_t*>(expert_weights);
for (int expert = 0; expert < num_experts; ++expert) {
int64_t tokens_start = tokens_per_expert_host[expert];
int64_t tokens_end = tokens_per_expert_host[expert + 1];
int64_t m = tokens_end - tokens_start;
if (m <= 0) continue;
problem_sizes_host.emplace_back(
static_cast<int>(m),
static_cast<int>(gemm_n),
static_cast<int>(gemm_k)
);
// Input A: [M, K/2] bytes per expert slice, row-major
ptr_A_host.push_back(reinterpret_cast<ActivationType*>(
const_cast<uint8_t*>(input_bytes + tokens_start * k_packed)));
// Weight B: [N, K/2] bytes per expert (col-major [K, N])
ptr_B_host.push_back(reinterpret_cast<WeightType*>(
const_cast<uint8_t*>(weight_bytes + expert * gemm_n * k_packed)));
ptr_C_host.push_back(nullptr); // Will be set after int32_output allocation
ptr_D_host.push_back(nullptr);
// Leading dimensions, in elements of each operand's type
lda_host.push_back(gemm_k); // A row-major [M, K]: stride between rows = K (int4 elements)
ldb_host.push_back(gemm_k); // B col-major [K, N]: stride between columns = K (int4 elements)
ldc_host.push_back(gemm_n); // C row-major [M, N]: stride between rows = N (int32 elements)
total_output_elements += m * gemm_n;
}
int problem_count = static_cast<int>(problem_sizes_host.size());
if (problem_count == 0) return;
// Allocate INT32 accumulator buffer
cutlass::DeviceAllocation<ElementAccumulator> int32_accum(total_output_elements);
// Setup output pointers for int32 accumulator
int64_t offset = 0;
for (int i = 0; i < problem_count; ++i) {
ptr_C_host[i] = int32_accum.get() + offset;
ptr_D_host[i] = int32_accum.get() + offset;
int m = problem_sizes_host[i].m();
int n = problem_sizes_host[i].n();
offset += m * n;
}
// Allocate device memory for problem descriptors
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device(problem_count);
cutlass::DeviceAllocation<int64_t> lda_device(problem_count);
cutlass::DeviceAllocation<int64_t> ldb_device(problem_count);
cutlass::DeviceAllocation<int64_t> ldc_device(problem_count);
cutlass::DeviceAllocation<ActivationType*> ptr_A_device(problem_count);
cutlass::DeviceAllocation<WeightType*> ptr_B_device(problem_count);
cutlass::DeviceAllocation<ElementAccumulator*> ptr_C_device(problem_count);
cutlass::DeviceAllocation<ElementAccumulator*> ptr_D_device(problem_count);
// Copy to device
cudaMemcpy(problem_sizes_device.get(), problem_sizes_host.data(),
sizeof(cutlass::gemm::GemmCoord) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(lda_device.get(), lda_host.data(),
sizeof(int64_t) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(ldb_device.get(), ldb_host.data(),
sizeof(int64_t) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(ldc_device.get(), ldc_host.data(),
sizeof(int64_t) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(ptr_A_device.get(), ptr_A_host.data(),
sizeof(ActivationType*) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(ptr_B_device.get(), ptr_B_host.data(),
sizeof(WeightType*) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(ptr_C_device.get(), ptr_C_host.data(),
sizeof(ElementAccumulator*) * problem_count, cudaMemcpyHostToDevice);
cudaMemcpy(ptr_D_device.get(), ptr_D_host.data(),
sizeof(ElementAccumulator*) * problem_count, cudaMemcpyHostToDevice);
// Calculate threadblock count
int threadblock_count = GemmGrouped::sufficient(problem_sizes_host.data(), problem_count);
// Configure epilogue
typename GemmGrouped::EpilogueOutputOp::Params epilogue_params{
ElementAccumulator(1),
ElementAccumulator(0)
};
// Configure GEMM arguments with device pointers
typename GemmGrouped::Arguments arguments(
problem_sizes_device.get(),
problem_count,
threadblock_count,
epilogue_params,
ptr_A_device.get(),
ptr_B_device.get(),
ptr_C_device.get(),
ptr_D_device.get(),
lda_device.get(),
ldb_device.get(),
ldc_device.get(),
ldc_device.get(),
problem_sizes_host.data()
);
// Initialize and run GEMM
GemmGrouped gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(arguments);
cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);
cutlass::Status status = gemm_op.initialize(arguments, workspace.get(), stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("Failed to initialize CUTLASS group GEMM: " +
std::to_string(static_cast<int>(status)));
}
status = gemm_op(stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("Failed to run CUTLASS group GEMM: " +
std::to_string(static_cast<int>(status)));
}
// Apply optimized vectorized dequantization kernel
launch_moe_dequant_optimized<OutputType>(
output,
int32_accum.get(),
input_scales,
weight_scales,
total_tokens_including_expert,
static_cast<int>(num_rows),
static_cast<int>(gemm_n),
num_experts,
stream
);
// Synchronize to ensure int32_accum is not freed before kernel finishes
cudaStreamSynchronize(stream);
}
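For context, the separate dequantization pass invoked via launch_moe_dequant_optimized is not shown above. Conceptually it does something like the following; this is a hypothetical sketch only (the real kernel is not included here), and it assumes a precomputed per-row expert index rather than the total_tokens_including_expert prefix sums used in the launcher:

```cpp
// Hypothetical sketch of the post-GEMM dequantization pass. Assumes the INT32
// accumulator buffer is contiguous row-major in the same token order as the GEMM
// output, input_scales holds one float per token, weight_scales holds one float
// per (expert, output channel), and expert_for_row maps each token row to its expert.
template <typename OutputType>
__global__ void moe_dequant_kernel(OutputType* out,
                                   int32_t const* accum,
                                   float const* input_scales,   // [num_rows]
                                   float const* weight_scales,  // [num_experts, gemm_n]
                                   int const* expert_for_row,   // [num_rows] (assumed precomputed)
                                   int num_rows,
                                   int gemm_n)
{
    int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
    int64_t total = static_cast<int64_t>(num_rows) * gemm_n;
    if (idx >= total) {
        return;
    }

    int row = static_cast<int>(idx / gemm_n);
    int col = static_cast<int>(idx % gemm_n);
    int expert = expert_for_row[row];

    // D[row][col] = acc[row][col] * scale_act[row] * scale_weight[expert][col]
    float v = static_cast<float>(accum[idx]) * input_scales[row]
            * weight_scales[static_cast<int64_t>(expert) * gemm_n + col];
    out[idx] = static_cast<OutputType>(v);
}
```

A pass like this reads and writes the whole M x N output a second time, which is why folding the scaling into the GEMM epilogue is attractive.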
I have referred to the following code, but I have not been able to adapt it to what I want:
TensorRT-LLM/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h
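For what it is worth, the core of the per-row/per-column scaling in that TensorRT-LLM epilogue boils down to per-fragment code along the lines below. This is a simplified sketch of the idea only, not the actual class; the functor name, fragment types, and parameter layout here are assumptions. In the TensorRT-LLM code the logic lives in an epilogue visitor (epilogue_per_row_per_col_scale.h) that replaces the plain LinearCombination epilogue.

```cpp
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"

// Simplified, hypothetical sketch of a per-row, per-column scaling epilogue op.
// row_scale is the per-token activation scale for the output row being written;
// col_scale holds the per-channel weight scales for the kCount columns of this fragment.
template <typename ElementOutput, int kCount>
struct ScaledDequantOp {
    using FragmentAccumulator = cutlass::Array<int32_t, kCount>;
    using FragmentOutput      = cutlass::Array<ElementOutput, kCount>;
    using FragmentScale       = cutlass::Array<float, kCount>;

    CUTLASS_HOST_DEVICE
    FragmentOutput operator()(FragmentAccumulator const& accum,
                              float row_scale,
                              FragmentScale const& col_scale) const
    {
        cutlass::NumericArrayConverter<float, int32_t, kCount> to_float;
        cutlass::NumericArrayConverter<ElementOutput, float, kCount> to_out;

        // Convert INT32 accumulators to float, apply both broadcast scales, convert to output type.
        cutlass::Array<float, kCount> tmp = to_float(accum);
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kCount; ++i) {
            tmp[i] = tmp[i] * row_scale * col_scale[i];
        }
        return to_out(tmp);
    }
};
```

The extra difficulty in a grouped GEMM is that each problem in the group needs its own scale pointers, so under this approach the per-problem argument arrays would also have to carry an activation-scale pointer and a weight-scale pointer per expert, in the same way ptr_A/ptr_B/ptr_C/ptr_D are carried today.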