Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion xtuner/v1/float8/triton_kernels/trans_quant_per_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from xtuner.v1.float8.float8_utils import to_fp8_saturated


SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))


@triton.jit
def get_group_id(m, group_offsets, g_start, num_experts):
id = 0
Expand Down Expand Up @@ -123,7 +126,7 @@ def _trans_per_block_quant_expand_128x(
finfo = torch.finfo(dtype)
fmin = finfo.min
fmax = finfo.max
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - int(os.getenv("MinusSM", 0))
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN
grid = (NUM_SMS,)
num_experts = size_per_group.shape[0]
M, N = input_tensor.shape
Expand Down
8 changes: 7 additions & 1 deletion xtuner/v1/ops/moe/cuda/triton_kernels/k_grouped_gemm_TMA.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import torch
import triton
import triton.language as tl
Expand All @@ -6,6 +9,9 @@
from .utils import TmaAutoTuneHelper


SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))


def get_cuda_autotune_config():
return [
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 12}, num_stages=3, num_warps=8),
Expand Down Expand Up @@ -148,7 +154,7 @@ def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor
assert dtype_b >= 0, f"data type {B.dtype} not supported"
assert dtype_c >= 0, f"data type {C.dtype} not supported"

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN

desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("a")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional

import torch
Expand All @@ -7,6 +8,9 @@
from torch import Tensor


SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))


def get_cuda_autotune_config():
return [
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 12}, num_stages=5, num_warps=8),
Expand Down Expand Up @@ -160,7 +164,7 @@ def k_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor) -> Tensor
assert dtype_b >= 0, f"data type {B.dtype} not supported"
assert dtype_c >= 0, f"data type {C.dtype} not supported"

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN

def grid(META):
# assert N % META["BLOCK_N"] == 0, "Only support when N is a multiple of BLOCK_N"
Expand Down
13 changes: 8 additions & 5 deletions xtuner/v1/ops/moe/cuda/triton_kernels/m_grouped_gemm_TMA.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import torch
import triton
import triton.language as tl
Expand All @@ -7,6 +9,9 @@
from .utils import TmaAutoTuneHelper


SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))


def get_cuda_autotune_config():
return [
triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8),
Expand Down Expand Up @@ -229,9 +234,7 @@ def repeat_interleave(


@torch.library.custom_op("moe::m_grouped_gemm", mutates_args=())
def m_grouped_gemm(
A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1
) -> Tensor:
def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
assert A.dim() == 2
assert B.dim() == 3

Expand Down Expand Up @@ -264,7 +267,7 @@ def m_grouped_gemm(
group_end = size_per_group.cumsum(0) - size_per_group + size_per_group
group_start = size_per_group.cumsum(0) - size_per_group

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN

dtype_mapping = {torch.bfloat16: 0, torch.float16: 1}
dtype_a = dtype_mapping.get(A.dtype, -1)
Expand Down Expand Up @@ -347,7 +350,7 @@ def grid(META):


@m_grouped_gemm.register_fake
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor:
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
M, K = A.shape
if trans_b:
num_groups, N, BK = B.shape
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional

import torch
Expand All @@ -7,6 +8,9 @@
from torch import Tensor


SM_MARGIN = int(os.environ.get("XTUNER_SM_MARGIN", 0))


def get_cuda_autotune_config():
return [
triton.Config({"BLOCK_N": 64, "BLOCK_K": 256, "GROUP_M": 6}, num_stages=3, num_warps=8),
Expand Down Expand Up @@ -245,9 +249,7 @@ def repeat_interleave(


@torch.library.custom_op("moe::m_grouped_gemm", mutates_args=())
def m_grouped_gemm(
A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1
) -> Tensor:
def m_grouped_gemm(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
assert A.dim() == 2
assert B.dim() == 3

Expand Down Expand Up @@ -280,7 +282,7 @@ def m_grouped_gemm(
group_end = size_per_group.cumsum(0) - size_per_group + size_per_group
group_start = size_per_group.cumsum(0) - size_per_group

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count if numSM <= 0 else numSM
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - SM_MARGIN

dtype_mapping = {torch.bfloat16: 0, torch.float16: 1}
dtype_a = dtype_mapping.get(A.dtype, -1)
Expand Down Expand Up @@ -324,7 +326,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):


@m_grouped_gemm.register_fake
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False, numSM: int = -1) -> Tensor:
def _(A: Tensor, B: Tensor, size_per_group: torch.Tensor, trans_b: bool = False) -> Tensor:
M, K = A.shape
if trans_b:
num_groups, N, BK = B.shape
Expand Down
1 change: 1 addition & 0 deletions xtuner/v1/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,6 +1786,7 @@ def _setup_env(self):
"XTUNER_USE_CUTLASS_GROUP_GEMM": os.getenv("XTUNER_USE_CUTLASS_GROUP_GEMM"),
"GROUPED_GEMM_USE_CUTLASS": os.getenv("GROUPED_GEMM_USE_CUTLASS"),
"XTUNER_USE_NATIVE_RMSNORM": os.getenv("XTUNER_USE_NATIVE_RMSNORM"),
"XTUNER_SM_MARGIN": os.getenv("XTUNER_SM_MARGIN"),
}

for k, v in env.items():
Expand Down