Conversation
Pull request overview
This PR refactors the A16W16 ASM GEMM path to avoid Torch/pybind dependencies by introducing a C ABI entrypoint invoked via ctypes, along with a lightweight tensor/dtype bridge shared between C++ and Python.
Changes:
- Replaces the `torch::Tensor`/pybind interface for `gemm_a16w16_asm` with an exported `extern "C"` function taking an `AiterTensor` descriptor and an explicit `hipStream_t`.
- Adds shared dtype/tensor definitions (`AiterDtype`, `AiterTensor`) and Python utilities to mirror them and perform zero-copy conversions from `torch.Tensor`.
- Extends JIT `compile_ops` to support a `torch_exclude` mode that loads and calls the built `.so` via `ctypes` instead of `torch.ops`.
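The `torch_exclude` mechanism boils down to opening the built shared object with `ctypes` and calling the exported C symbol directly. A minimal sketch of that pattern, using libc as a stand-in for the real module `.so` (the actual library path and symbol names come from the JIT build):

```python
import ctypes
import ctypes.util

# libc stands in for the JIT-built module .so; the real code would
# CDLL() the path produced by the build and look up the extern "C" symbol.
lib_name = ctypes.util.find_library("c") or "libc.so.6"
libc = ctypes.CDLL(lib_name)

# Declaring argtypes/restype mirrors what the torch_exclude path must do
# for each exported entrypoint before calling it.
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int

print(libc.abs(-7))  # → 7
```

The same `argtypes`/`restype` declarations are what the `_ctypes_call()` helper derives from the Python type hints of the wrapped op.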
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| csrc/py_itfs_cu/asm_gemm_a16w16.cu | Switches GEMM entrypoint to a C ABI using AiterTensor + explicit stream; replaces TORCH_CHECK with AITER_CHECK. |
| csrc/include/aiter_hip_common.h | Adds AITER_CHECK macro and defines the AiterTensor struct used by the C ABI. |
| csrc/include/aiter_enum.h | Introduces AiterDtype enum plus helpers for element sizing and stringification. |
| aiter/utility/dtypes.py | Builds dtype mappings from the parsed C++ header and adds torch_to_aiter() conversion to AiterTensor. |
| aiter/utility/aiter_types.py | New: parses AiterDtype from aiter_enum.h and defines the ctypes.Structure for AiterTensor. |
| aiter/ops/gemm_op_a16w16.py | Marks the op as torch_exclude=True so it routes through the new ctypes-based call path. |
| aiter/jit/optCompilerConfig.json | Removes the pybind source from module_gemm_a16w16_asm, leaving only the C ABI implementation. |
| aiter/jit/core.py | Adds _ctypes_call() and a torch_exclude mode in compile_ops() to load/call non-Torch .so modules. |
```python
    next_val = int(val.strip())
else:
    name = line
result[name.removeprefix(_PREFIX)] = next_val
```
name.removeprefix(_PREFIX) requires Python 3.9+, but this repo declares python_requires >= 3.8. On Python 3.8 this will raise AttributeError during import, breaking all dtype handling. Please replace with a 3.8-compatible prefix strip (e.g., manual startswith/slice) or bump the minimum Python version accordingly.
Suggested change:
```diff
-result[name.removeprefix(_PREFIX)] = next_val
+short_name = name[len(_PREFIX):] if name.startswith(_PREFIX) else name
+result[short_name] = next_val
```
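A small standalone sketch of the 3.8-compatible prefix strip (the helper name is illustrative, not from the PR):

```python
def strip_prefix(name: str, prefix: str) -> str:
    # Python 3.8-compatible equivalent of name.removeprefix(prefix),
    # which only exists on str from Python 3.9 onward.
    return name[len(prefix):] if name.startswith(prefix) else name

print(strip_prefix("AITER_DTYPE_bf16", "AITER_DTYPE_"))  # → bf16
print(strip_prefix("bf16", "AITER_DTYPE_"))              # → bf16 (no-op)
```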
```python
c_args.append(
    ctypes.c_void_p(torch.cuda.current_stream().cuda_stream)
)  # stream
```
The ctypes wrapper always passes torch.cuda.current_stream().cuda_stream without selecting the stream for the tensors' device. If callers have tensors on a non-current device, this can pass a stream from the wrong device into the HIP kernel launch (potentially leading to invalid launches or memory corruption). Please pick the stream for the relevant tensor device (e.g., out/A) or accept a stream argument from the caller.
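A torch-free sketch of the distinction being made here: the stream lookup must be keyed on the tensor's own device, not on the process-global "current" device. All names and handle values below are hypothetical; in the real code the lookup would be `torch.cuda.current_stream(tensor.device).cuda_stream`.

```python
# Hypothetical per-device stream handles (stand-ins for HIP streams).
streams_by_device = {0: 0x1000, 1: 0x2000}
current_device = 0  # process-global "current" device

class FakeTensor:
    """Minimal stand-in for a torch.Tensor carrying a device index."""
    def __init__(self, device_index: int):
        self.device_index = device_index

def stream_for(tensor: FakeTensor) -> int:
    # Correct: pick the stream of the tensor's own device.
    return streams_by_device[tensor.device_index]

out = FakeTensor(device_index=1)
# The buggy pattern would grab streams_by_device[current_device] (0x1000);
# the fix selects the device-1 stream the tensor actually lives on.
assert stream_for(out) == 0x2000
assert stream_for(out) != streams_by_device[current_device]
```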
```python
hints = typing.get_type_hints(func)
argtypes = []
for pname in inspect.signature(func).parameters:
    hint = hints.get(pname)  ### type hint
    origin = typing.get_origin(hint)  ### check if union type
    type_args = typing.get_args(hint)
    if hint is torch.Tensor:
        argtypes.append(ctypes.POINTER(AiterTensor))
    elif origin is typing.Union and torch.Tensor in type_args:
        argtypes.append(ctypes.POINTER(AiterTensor))
    elif origin is typing.Union and int in type_args:
        argtypes.append(ctypes.c_int)
    elif origin is typing.Union and str in type_args:
        argtypes.append(ctypes.c_char_p)
    elif hint is bool:
        argtypes.append(ctypes.c_int)
    elif hint is int:
        argtypes.append(ctypes.c_int)
    else:
        argtypes.append(ctypes.c_void_p)
```
_ctypes_call only recognizes optionals when typing.get_origin(hint) is typing.Union. This misses PEP 604 unions (Tensor | None, int | None, etc.), which are used widely in this repo, and will cause incorrect ctypes argtypes/conversions (falling back to c_void_p). Please also handle types.UnionType (e.g., origin in {typing.Union, types.UnionType}) in both the argtypes building and the argument conversion logic.
csrc/py_itfs_cu/asm_gemm_a16w16.cu (Outdated)
```diff
+size_t A_elem_size = AiterDtype_element_size(A->dtype);
+size_t B_elem_size = AiterDtype_element_size(B->dtype);
+size_t out_elem_size = AiterDtype_element_size(out->dtype);

 KernelArgs args = {};
-args.ptr_D = (void*)out.data_ptr();
+args.ptr_D = out->ptr;
 args.ptr_C = nullptr;
-args.ptr_A = (void*)A.data_ptr();
-args.ptr_B = (void*)B.data_ptr();
-args.ptr_Bias = bias.has_value() ? (void*)bias.value().data_ptr() : nullptr;
+args.ptr_A = A->ptr;
+args.ptr_B = B->ptr;
+args.ptr_Bias = bias ? bias->ptr : nullptr;
 args.alpha = 1.0f;
 args.beta = 0.0f;
-args.stride_A0 = A.stride(0) * A.element_size();
-args.stride_B0 = B.stride(0) * B.element_size();
-args.stride_C0 = args.stride_D0 = Ndim * out.element_size();
+args.stride_A0 = A->strides[0] * A_elem_size;
+args.stride_B0 = B->strides[0] * B_elem_size;
+args.stride_C0 = args.stride_D0 = Ndim * out_elem_size;
```
AiterDtype_element_size() returns 0 for unknown dtypes, but the result is used to compute byte strides. Currently only out->dtype is validated; A/B dtypes (and the computed element sizes) are unchecked, so an unsupported dtype could silently produce zero/incorrect strides and wrong memory accesses. Please validate A/B dtypes (and/or *_elem_size > 0) before using them in stride math.
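The requested validation can be mirrored in a few lines of Python (the mapping below is an illustrative subset, not the real table; the real one is `AiterDtype_element_size()` in C++): treat the `0` sentinel for unknown dtypes as an error before it can reach stride arithmetic.

```python
# Illustrative subset of dtype -> element size in bytes (assumed values).
ELEMENT_SIZE = {"fp16": 2, "bf16": 2, "fp32": 4, "i8": 1}

def checked_elem_size(dtype_name: str) -> int:
    # Mirrors the suggested C++ fix: element_size() returns 0 for unknown
    # dtypes, so reject 0 instead of letting it zero out byte strides.
    size = ELEMENT_SIZE.get(dtype_name, 0)
    if size == 0:
        raise ValueError(f"unsupported dtype: {dtype_name}")
    return size

print(checked_elem_size("bf16"))  # → 2
```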
```diff
@@ -18,3 +19,52 @@ enum class QuantType : int
     per_1x128,
     per_128x128,
 };
+typedef enum {
+    AITER_DTYPE_fp8,
+    AITER_DTYPE_fp8_e8m0,
+    AITER_DTYPE_fp16,
+    AITER_DTYPE_bf16,
+    AITER_DTYPE_fp32,
+    AITER_DTYPE_i4x2,
+    AITER_DTYPE_fp4x2,
+    AITER_DTYPE_u32,
+    AITER_DTYPE_i32,
+    AITER_DTYPE_i16,
+    AITER_DTYPE_i8,
+} AiterDtype;
+
+static inline size_t AiterDtype_element_size(AiterDtype dtype)
+{
+    switch (dtype) {
```
aiter_enum.h defines functions using size_t but doesn't include a header that guarantees size_t is declared (<cstddef>/<stddef.h>). This relies on transitive includes from other headers and can break compilation when aiter_enum.h is included standalone. Please include <cstddef> (or <stddef.h>) directly here.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist