|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import os |
2 | 3 | from typing import Optional |
3 | 4 |
|
4 | 5 | import torch |
|
7 | 8 | from torch import Tensor |
8 | 9 |
|
9 | 10 |
|
| 11 | +SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0)) |
| 12 | + |
| 13 | + |
10 | 14 | def get_cuda_autotune_config(): |
11 | 15 | return [ |
12 | 16 | triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8), |
@@ -245,9 +249,7 @@ def repeat_interleave( |
245 | 249 |
|
246 | 250 |
|
247 | 251 | @torch.library.custom_op("moe::m_grouped_gemm", mutates_args=()) |
248 | | -def m_grouped_gemm( |
249 | | - A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1 |
250 | | -) -> Tensor: |
| 252 | +def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor: |
251 | 253 | assert A.dim() == 2 |
252 | 254 | assert B.dim() == 3 |
253 | 255 |
|
@@ -280,7 +282,7 @@ def m_grouped_gemm( |
280 | 282 | group_end = size_per_group.cumsum(0) - size_per_group + size_per_group |
281 | 283 | group_start = size_per_group.cumsum(0) - size_per_group |
282 | 284 |
|
283 | | - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM |
| 285 | + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN |
284 | 286 |
|
285 | 287 | dtype_mapping = {torch.bfloat16: 0, torch.float16: 1} |
286 | 288 | dtype_a = dtype_mapping.get(A.dtype, -1) |
@@ -324,7 +326,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): |
324 | 326 |
|
325 | 327 |
|
326 | 328 | @m_grouped_gemm.register_fake |
327 | | -def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor: |
| 329 | +def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor: |
328 | 330 | M, K = A.shape |
329 | 331 | if trans_b: |
330 | 332 | num_groups, N, BK = B.shape |
|
0 commit comments