refactor: use ctypes binding #2255

Open
amd-ruitang3 wants to merge 5 commits into main from refactor_bind_kl

Conversation

@amd-ruitang3
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist


Copilot AI left a comment


Pull request overview

This PR refactors the A16W16 ASM GEMM path to avoid Torch/pybind dependencies by introducing a C ABI entrypoint invoked via ctypes, along with a lightweight tensor/dtype bridge shared between C++ and Python.

Changes:

  • Replaces the torch::Tensor/pybind interface for gemm_a16w16_asm with an exported extern "C" function taking an AiterTensor descriptor and an explicit hipStream_t.
  • Adds shared dtype/tensor definitions (AiterDtype, AiterTensor) and Python utilities to mirror them and perform zero-copy conversions from torch.Tensor.
  • Extends JIT compile_ops to support a torch_exclude mode that loads and calls the built .so via ctypes instead of torch.ops.
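For background, the generic shape of such a ctypes call path is sketched below: load a shared library, then declare `argtypes`/`restype` by hand before calling. This uses libm's `cos` as a stand-in for the built aiter `.so`; none of the names here are the PR's actual code.

```python
import ctypes
import ctypes.util

# Load a shared library by path/name (libm stands in for the aiter .so).
libm = ctypes.CDLL(ctypes.util.find_library("m") or "libm.so.6")

# Declare the C signature explicitly; without this, ctypes would
# default to int-sized arguments and truncate the double.
libm.cos.argtypes = [ctypes.c_double]
libm.cos.restype = ctypes.c_double

result = libm.cos(0.0)  # calls the C function directly, no pybind layer
```

The same pattern (CDLL + per-function `argtypes`) is what a `torch_exclude` mode needs to do for each exported `extern "C"` entrypoint.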

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Summary per file:

  • csrc/py_itfs_cu/asm_gemm_a16w16.cu: Switches the GEMM entrypoint to a C ABI using AiterTensor + an explicit stream; replaces TORCH_CHECK with AITER_CHECK.
  • csrc/include/aiter_hip_common.h: Adds the AITER_CHECK macro and defines the AiterTensor struct used by the C ABI.
  • csrc/include/aiter_enum.h: Introduces the AiterDtype enum plus helpers for element sizing and stringification.
  • aiter/utility/dtypes.py: Builds dtype mappings from the parsed C++ header and adds torch_to_aiter() conversion to AiterTensor.
  • aiter/utility/aiter_types.py: New: parses AiterDtype from aiter_enum.h and defines the ctypes.Structure for AiterTensor.
  • aiter/ops/gemm_op_a16w16.py: Marks the op as torch_exclude=True so it routes through the new ctypes-based call path.
  • aiter/jit/optCompilerConfig.json: Removes the pybind source from module_gemm_a16w16_asm, leaving only the C ABI implementation.
  • aiter/jit/core.py: Adds _ctypes_call() and a torch_exclude mode in compile_ops() to load/call non-Torch .so modules.

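The tensor/dtype bridge described above hinges on a `ctypes.Structure` whose memory layout matches the C-side `AiterTensor`. A minimal sketch of such a mirror is shown here; the field names follow the diff (`ptr`, `dtype`, `strides`), but the full layout, field order, and the rank cap are assumptions, not the PR's exact ABI.

```python
import ctypes

MAX_DIMS = 8  # assumed rank cap; the PR's actual constant may differ

class AiterTensor(ctypes.Structure):
    # Hypothetical layout mirroring the C-side descriptor struct.
    # Field order must match the C definition byte-for-byte, since
    # ctypes passes this by pointer across the ABI boundary.
    _fields_ = [
        ("ptr", ctypes.c_void_p),            # device data pointer
        ("dtype", ctypes.c_int),             # AiterDtype enum value
        ("ndim", ctypes.c_int),              # number of used dims
        ("sizes", ctypes.c_int64 * MAX_DIMS),
        ("strides", ctypes.c_int64 * MAX_DIMS),
    ]
```

A zero-copy conversion from `torch.Tensor` would then fill `ptr` from `data_ptr()` and copy `sizes`/`strides` without touching the underlying storage.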

        next_val = int(val.strip())
    else:
        name = line
    result[name.removeprefix(_PREFIX)] = next_val

Copilot AI Mar 12, 2026


name.removeprefix(_PREFIX) requires Python 3.9+, but this repo declares python_requires >= 3.8. On Python 3.8 this will raise AttributeError during import, breaking all dtype handling. Please replace with a 3.8-compatible prefix strip (e.g., manual startswith/slice) or bump the minimum Python version accordingly.

Suggested change
-    result[name.removeprefix(_PREFIX)] = next_val
+    short_name = name[len(_PREFIX):] if name.startswith(_PREFIX) else name
+    result[short_name] = next_val
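The suggested replacement can be checked standalone; a small sketch (the helper name is illustrative, not from the PR):

```python
_PREFIX = "AITER_DTYPE_"

def strip_prefix(name: str) -> str:
    # Python 3.8-compatible equivalent of name.removeprefix(_PREFIX):
    # slice off the prefix only when it is actually present.
    return name[len(_PREFIX):] if name.startswith(_PREFIX) else name
```

Unlike a naive `name.replace(_PREFIX, "")`, this only strips a leading occurrence, matching `str.removeprefix` semantics.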

Comment on lines +984 to +986
c_args.append(
    ctypes.c_void_p(torch.cuda.current_stream().cuda_stream)
)  # stream

Copilot AI Mar 12, 2026


The ctypes wrapper always passes torch.cuda.current_stream().cuda_stream without selecting the stream for the tensors' device. If callers have tensors on a non-current device, this can pass a stream from the wrong device into the HIP kernel launch (potentially leading to invalid launches or memory corruption). Please pick the stream for the relevant tensor device (e.g., out/A) or accept a stream argument from the caller.

Comment on lines +916 to +935
hints = typing.get_type_hints(func)
argtypes = []
for pname in inspect.signature(func).parameters:
    hint = hints.get(pname)  # type hint
    origin = typing.get_origin(hint)  # check if union type
    type_args = typing.get_args(hint)
    if hint is torch.Tensor:
        argtypes.append(ctypes.POINTER(AiterTensor))
    elif origin is typing.Union and torch.Tensor in type_args:
        argtypes.append(ctypes.POINTER(AiterTensor))
    elif origin is typing.Union and int in type_args:
        argtypes.append(ctypes.c_int)
    elif origin is typing.Union and str in type_args:
        argtypes.append(ctypes.c_char_p)
    elif hint is bool:
        argtypes.append(ctypes.c_int)
    elif hint is int:
        argtypes.append(ctypes.c_int)
    else:
        argtypes.append(ctypes.c_void_p)

Copilot AI Mar 12, 2026


_ctypes_call only recognizes optionals when typing.get_origin(hint) is typing.Union. This misses PEP604 unions (Tensor | None, int | None, etc.), which are used widely in this repo, and will cause incorrect ctypes argtypes/conversions (falling back to c_void_p). Please also handle types.UnionType / origin is types.UnionType (and/or origin in {typing.Union, types.UnionType}) in both the argtypes building and argument conversion logic.

Comment on lines +205 to +219
+    size_t A_elem_size = AiterDtype_element_size(A->dtype);
+    size_t B_elem_size = AiterDtype_element_size(B->dtype);
+    size_t out_elem_size = AiterDtype_element_size(out->dtype);
+
     KernelArgs args = {};
-    args.ptr_D = (void*)out.data_ptr();
+    args.ptr_D = out->ptr;
     args.ptr_C = nullptr;
-    args.ptr_A = (void*)A.data_ptr();
-    args.ptr_B = (void*)B.data_ptr();
-    args.ptr_Bias = bias.has_value() ? (void*)bias.value().data_ptr() : nullptr;
+    args.ptr_A = A->ptr;
+    args.ptr_B = B->ptr;
+    args.ptr_Bias = bias ? bias->ptr : nullptr;
     args.alpha = 1.0f;
     args.beta = 0.0f;
-    args.stride_A0 = A.stride(0) * A.element_size();
-    args.stride_B0 = B.stride(0) * B.element_size();
-    args.stride_C0 = args.stride_D0 = Ndim * out.element_size();
+    args.stride_A0 = A->strides[0] * A_elem_size;
+    args.stride_B0 = B->strides[0] * B_elem_size;
+    args.stride_C0 = args.stride_D0 = Ndim * out_elem_size;

Copilot AI Mar 12, 2026


AiterDtype_element_size() returns 0 for unknown dtypes, but the result is used to compute byte strides. Currently only out->dtype is validated; A/B dtypes (and the computed element sizes) are unchecked, so an unsupported dtype could silently produce zero/incorrect strides and wrong memory accesses. Please validate A/B dtypes (and/or *_elem_size > 0) before using them in stride math.
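The fail-fast pattern being requested can be sketched on the Python side as well, where the dtype table is parsed from the header. This is a hypothetical mirror of `AiterDtype_element_size`; the dtype set and sizes below are illustrative, not the full table from aiter_enum.h.

```python
# Hypothetical dtype -> byte-size table; sub-byte and unknown dtypes
# are deliberately absent so lookups for them fail loudly.
_ELEMENT_SIZE = {
    "fp8": 1, "fp16": 2, "bf16": 2, "fp32": 4,
    "u32": 4, "i32": 4, "i16": 2, "i8": 1,
}

def checked_element_size(dtype: str) -> int:
    # Validate before stride math: a silent 0 would zero out every
    # byte stride computed from it and corrupt kernel addressing.
    size = _ELEMENT_SIZE.get(dtype, 0)
    if size == 0:
        raise ValueError(f"unsupported dtype for stride math: {dtype}")
    return size
```

The same guard in C++ would check each `*_elem_size > 0` (or each of A/B/out dtypes) before computing `stride_* = strides[0] * elem_size`.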

Comment on lines 4 to +38
@@ -18,3 +19,52 @@ enum class QuantType : int
     per_1x128,
     per_128x128,
 };
+typedef enum {
+    AITER_DTYPE_fp8,
+    AITER_DTYPE_fp8_e8m0,
+    AITER_DTYPE_fp16,
+    AITER_DTYPE_bf16,
+    AITER_DTYPE_fp32,
+    AITER_DTYPE_i4x2,
+    AITER_DTYPE_fp4x2,
+    AITER_DTYPE_u32,
+    AITER_DTYPE_i32,
+    AITER_DTYPE_i16,
+    AITER_DTYPE_i8,
+} AiterDtype;
+
+static inline size_t AiterDtype_element_size(AiterDtype dtype)
+{
+    switch (dtype) {

Copilot AI Mar 12, 2026


aiter_enum.h defines functions using size_t but doesn't include a header that guarantees size_t is declared (<cstddef>/<stddef.h>). This relies on transitive includes from other headers and can break compilation when aiter_enum.h is included standalone. Please include <cstddef> (or <stddef.h>) directly here.

amd-ruitang3 and others added 4 commits March 12, 2026 04:05