diff --git a/xtuner/v1/float8/triton_kernels/trans_quant_per_block.py b/xtuner/v1/float8/triton_kernels/trans_quant_per_block.py
index 5e6dfca9f..f182901cd 100644
--- a/xtuner/v1/float8/triton_kernels/trans_quant_per_block.py
+++ b/xtuner/v1/float8/triton_kernels/trans_quant_per_block.py
@@ -9,6 +9,9 @@
 from xtuner.v1.float8.float8_utils import to_fp8_saturated
 
 
+SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
+
+
 @triton.jit
 def get_group_id(m, group_offsets, g_start, num_experts):
     id = 0
@@ -123,7 +126,7 @@ def _trans_per_block_quant_expand_128x(
     finfo = torch.finfo(dtype)
     fmin = finfo.min
     fmax = finfo.max
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - int(os.getenv("MinusSM", 0))
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
     grid = (NUM_SMS,)
     num_experts = size_per_group.shape[0]
     M, N = input_tensor.shape
diff --git a/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA.py b/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA.py
index 5e79e7201..a6662fab5 100644
--- a/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA.py
+++ b/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA.py
@@ -1,3 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
 import torch
 import triton
 import triton.language as tl
@@ -6,6 +9,9 @@
 from .utils import TmaAutoTuneHelper
 
 
+SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 12}, num_stages=3, num_warps=8),
@@ -148,7 +154,7 @@ def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor
     assert dtype_b >= 0, f"data type {B.dtype} not supported"
     assert dtype_c >= 0, f"data type {C.dtype} not supported"
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
 
     desc_helper = TmaAutoTuneHelper()
     desc_helper.init_tma_descriptor("a")
diff --git a/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA_triton3_4.py b/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA_triton3_4.py
index a307bf7d1..63b280901 100644
--- a/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA_triton3_4.py
+++ b/xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA_triton3_4.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 from typing import Optional
 
 import torch
@@ -7,6 +8,9 @@
 from torch import Tensor
 
 
+SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 12}, num_stages=5, num_warps=8),
@@ -160,7 +164,7 @@ def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor
     assert dtype_b >= 0, f"data type {B.dtype} not supported"
     assert dtype_c >= 0, f"data type {C.dtype} not supported"
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
 
     def grid(META):
         # assert N % META["BLOCK_N"] == 0, "Only support when N is a multiple of BLOCK_N"
diff --git a/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA.py b/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA.py
index 67c928f54..1e74c070f 100644
--- a/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA.py
+++ b/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
+
 import torch
 import triton
 import triton.language as tl
@@ -7,6 +9,9 @@
 from .utils import TmaAutoTuneHelper
 
 
+SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8),
@@ -229,9 +234,7 @@ def repeat_interleave(
 
 
 @torch.library.custom_op("moe::m_grouped_gemm", mutates_args=())
-def m_grouped_gemm(
-    A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1
-) -> Tensor:
+def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
     assert A.dim() == 2
     assert B.dim() == 3
 
@@ -264,7 +267,7 @@ def m_grouped_gemm(
     group_end = size_per_group.cumsum(0) - size_per_group + size_per_group
     group_start = size_per_group.cumsum(0) - size_per_group
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
 
     dtype_mapping = {torch.bfloat16: 0, torch.float16: 1}
     dtype_a = dtype_mapping.get(A.dtype, -1)
@@ -347,7 +350,7 @@ def grid(META):
 
 
 @m_grouped_gemm.register_fake
-def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor:
+def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
     M, K = A.shape
     if trans_b:
         num_groups, N, BK = B.shape
diff --git a/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA_triton3_4.py b/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA_triton3_4.py
index 6fc4f5bf9..4d198db5d 100644
--- a/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA_triton3_4.py
+++ b/xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA_triton3_4.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 from typing import Optional
 
 import torch
@@ -7,6 +8,9 @@
 from torch import Tensor
 
 
+SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8),
@@ -245,9 +249,7 @@ def repeat_interleave(
 
 
 @torch.library.custom_op("moe::m_grouped_gemm", mutates_args=())
-def m_grouped_gemm(
-    A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1
-) -> Tensor:
+def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
     assert A.dim() == 2
     assert B.dim() == 3
 
@@ -280,7 +282,7 @@ def m_grouped_gemm(
     group_end = size_per_group.cumsum(0) - size_per_group + size_per_group
     group_start = size_per_group.cumsum(0) - size_per_group
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
 
     dtype_mapping = {torch.bfloat16: 0, torch.float16: 1}
     dtype_a = dtype_mapping.get(A.dtype, -1)
@@ -324,7 +326,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
 
 
 @m_grouped_gemm.register_fake
-def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor:
+def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
     M, K = A.shape
     if trans_b:
         num_groups, N, BK = B.shape
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 645d14311..f9640e2ca 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -1786,6 +1786,7 @@ def _setup_env(self):
             "XTUNER_USE_CUTLASS_GROUP_GEMM": os.getenv("XTUNER_USE_CUTLASS_GROUP_GEMM"),
             "GROUPED_GEMM_USE_CUTLASS": os.getenv("GROUPED_GEMM_USE_CUTLASS"),
             "XTUNER_USE_NATIVE_RMSNORM": os.getenv("XTUNER_USE_NATIVE_RMSNORM"),
+            "XTUNER_SM_MARGIN": os.getenv("XTUNER_SM_MARGIN"),
         }
 
         for k, v in env.items():
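Taken together, the kernel-side changes apply one pattern: read `XTUNER_SM_MARGIN` once at module import (replacing both the ad-hoc `MinusSM` variable and the per-call `numSM` argument), then subtract it from the device's SM count when sizing the persistent launch grid. Below is a minimal sketch of that pattern, assuming a persistent-style kernel that launches one program per available SM; the `launch_grid` helper, the `assert` guard, and the comm-overlap rationale are illustrative, not part of the diff.

```python
# Minimal sketch of the SM-margin pattern this diff applies; not xtuner code.
import os

import torch

# Read once at import time; defaults to 0, i.e. use every SM on the device.
SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))


def launch_grid() -> tuple:
    """Grid for a persistent kernel: one program per SM, minus the margin.

    A nonzero margin leaves SMs free for concurrent work such as
    communication kernels overlapped with compute (our reading of the
    intent; the diff itself states no rationale).
    """
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
    # Guard not present in the diff: a margin >= the SM count would give an empty grid.
    assert num_sms > 0, "XTUNER_SM_MARGIN must be smaller than the device SM count"
    return (num_sms,)
```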
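Note that this is a breaking change to the `moe::m_grouped_gemm` custom op: `numSM` is removed from both the op and its `register_fake` stub, and the semantics differ — `numSM` set an absolute SM count per call, while `XTUNER_SM_MARGIN` is a process-wide margin subtracted from the device count. A hypothetical migration for a caller (tensor shapes, dtypes, and the margin value are assumed for illustration):

```python
import os

# The margin is read at import time, so set it before the kernel module is
# imported (or export it in the launch environment instead).
os.environ["XTUNER_SM_MARGIN"] = "12"

import torch
from xtuner.v1.ops.moe.cuda.triton_kernels.m_grouped_gemm_TMA import m_grouped_gemm

# Assumed shapes: A is (M, K); with trans_b=True, B is (num_groups, N, K),
# and size_per_group sums to M.
a = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 1024, 512, device="cuda", dtype=torch.bfloat16)
size_per_group = torch.tensor([64, 64, 64, 64], device="cuda", dtype=torch.int64)

# Before: m_grouped_gemm(a, b, size_per_group, trans_b=True, numSM=120)
out = m_grouped_gemm(a, b, size_per_group, trans_b=True)  # numSM is gone
```

Equivalently, the variable can be set per launch, e.g. `XTUNER_SM_MARGIN=12 python train.py` (command illustrative); the trainer change above makes the effective value visible in the logged environment.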