Skip to content

Commit 1f42c51

Browse files
committed
[Feature] add environment variable XTUNER_SM_MARGIN to control the number of SMs available to compute kernels
1 parent 76fdef3 commit 1f42c51

6 files changed

Lines changed: 32 additions & 13 deletions

File tree

xtuner/v1/float8/triton_kernels/trans_quant_per_block.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from xtuner.v1.float8.float8_utils import to_fp8_saturated
1010

1111

12+
SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
13+
14+
1215
@triton.jit
1316
def get_group_id(m, group_offsets, g_start, num_experts):
1417
id = 0
@@ -123,7 +126,7 @@ def _trans_per_block_quant_expand_128x(
123126
finfo = torch.finfo(dtype)
124127
fmin = finfo.min
125128
fmax = finfo.max
126-
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - int(os.getenv("MinusSM", 0))
129+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
127130
grid = (NUM_SMS,)
128131
num_experts = size_per_group.shape[0]
129132
M, N = input_tensor.shape

xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
3+
14
import torch
25
import triton
36
import triton.language as tl
@@ -6,6 +9,9 @@
69
from .utils import TmaAutoTuneHelper
710

811

12+
SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
13+
14+
915
def get_cuda_autotune_config():
1016
return [
1117
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 12}, num_stages=3, num_warps=8),
@@ -148,7 +154,7 @@ def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor
148154
assert dtype_b >= 0, f"data type {B.dtype} not supported"
149155
assert dtype_c >= 0, f"data type {C.dtype} not supported"
150156

151-
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
157+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
152158

153159
desc_helper = TmaAutoTuneHelper()
154160
desc_helper.init_tma_descriptor("a")

xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA_triton3_4.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
23
from typing import Optional
34

45
import torch
@@ -7,6 +8,9 @@
78
from torch import Tensor
89

910

11+
SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
12+
13+
1014
def get_cuda_autotune_config():
1115
return [
1216
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 12}, num_stages=5, num_warps=8),
@@ -160,7 +164,7 @@ def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor
160164
assert dtype_b >= 0, f"data type {B.dtype} not supported"
161165
assert dtype_c >= 0, f"data type {C.dtype} not supported"
162166

163-
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
167+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
164168

165169
def grid(META):
166170
# assert N % META["BLOCK_N"] == 0, "Only support when N is a multiple of BLOCK_N"

xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
3+
24
import torch
35
import triton
46
import triton.language as tl
@@ -7,6 +9,9 @@
79
from .utils import TmaAutoTuneHelper
810

911

12+
SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
13+
14+
1015
def get_cuda_autotune_config():
1116
return [
1217
triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8),
@@ -229,9 +234,7 @@ def repeat_interleave(
229234

230235

231236
@torch.library.custom_op("moe::m_grouped_gemm", mutates_args=())
232-
def m_grouped_gemm(
233-
A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1
234-
) -> Tensor:
237+
def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
235238
assert A.dim() == 2
236239
assert B.dim() == 3
237240

@@ -264,7 +267,7 @@ def m_grouped_gemm(
264267
group_end = size_per_group.cumsum(0) - size_per_group + size_per_group
265268
group_start = size_per_group.cumsum(0) - size_per_group
266269

267-
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM
270+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
268271

269272
dtype_mapping = {torch.bfloat16: 0, torch.float16: 1}
270273
dtype_a = dtype_mapping.get(A.dtype, -1)
@@ -347,7 +350,7 @@ def grid(META):
347350

348351

349352
@m_grouped_gemm.register_fake
350-
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor:
353+
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
351354
M, K = A.shape
352355
if trans_b:
353356
num_groups, N, BK = B.shape

xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA_triton3_4.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
23
from typing import Optional
34

45
import torch
@@ -7,6 +8,9 @@
78
from torch import Tensor
89

910

11+
SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))
12+
13+
1014
def get_cuda_autotune_config():
1115
return [
1216
triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8),
@@ -245,9 +249,7 @@ def repeat_interleave(
245249

246250

247251
@torch.library.custom_op("moe::m_grouped_gemm", mutates_args=())
248-
def m_grouped_gemm(
249-
A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1
250-
) -> Tensor:
252+
def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
251253
assert A.dim() == 2
252254
assert B.dim() == 3
253255

@@ -280,7 +282,7 @@ def m_grouped_gemm(
280282
group_end = size_per_group.cumsum(0) - size_per_group + size_per_group
281283
group_start = size_per_group.cumsum(0) - size_per_group
282284

283-
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM
285+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
284286

285287
dtype_mapping = {torch.bfloat16: 0, torch.float16: 1}
286288
dtype_a = dtype_mapping.get(A.dtype, -1)
@@ -324,7 +326,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
324326

325327

326328
@m_grouped_gemm.register_fake
327-
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor:
329+
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
328330
M, K = A.shape
329331
if trans_b:
330332
num_groups, N, BK = B.shape

xtuner/v1/train/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,6 +1786,7 @@ def _setup_env(self):
17861786
"XTUNER_USE_CUTLASS_GROUP_GEMM": os.getenv("XTUNER_USE_CUTLASS_GROUP_GEMM"),
17871787
"GROUPED_GEMM_USE_CUTLASS": os.getenv("GROUPED_GEMM_USE_CUTLASS"),
17881788
"XTUNER_USE_NATIVE_RMSNORM": os.getenv("XTUNER_USE_NATIVE_RMSNORM"),
1789+
"XTUNER_SM_MARGIN": os.getenv("XTUNER_SM_MARGIN"),
17891790
}
17901791

17911792
for k, v in env.items():

0 commit comments

Comments (0)