Added JIT capabilities into all operators except transform operators.#1085
Added JIT capabilities into all operators except transform operators.#1085cliffburdick merged 1 commit intomainfrom
Conversation
89f3eab to
055fb87
Compare
|
/build |
There was a problem hiding this comment.
Greptile Overview
Greptile Summary
This PR adds comprehensive JIT (Just-In-Time) compilation support to all MatX operators and generators except transform operators. The implementation follows a consistent pattern across all affected files.
Key Changes
- Core Infrastructure: Enhanced capability system queries for
JIT_TYPE_QUERYandJIT_CLASS_QUERYto enable runtime code generation - Operator Pattern: Each operator now includes:
JIT_Storagestruct with serializable membersToJITStorage()method for converting runtime state to JIT-compatible formget_jit_class_name()andget_jit_op_str()for code generation- Enhanced
get_capability()with JIT-specific capability handling
- Executor:
CUDAJITExecutorrefactored with improved launch parameter selection using capability queries - Grid Dimensions: Extended rank 3 and 4 support in
get_grid_dims_jit(), added assertion for groups_per_block divisibility - RTC Compatibility: Changed
std::tocuda::std::types throughout for NVRTC compiler compatibility
Operators/Generators Modified (66 files total)
- All generators: alternate, bartlett, blackman, chirp, diag, fftfreq, flattop, hamming, hanning, linspace, logspace, meshgrid, range
- All operators: at, cart2sph, cast, clone, collapse, comma, constval, cross, diag, fftshift, flatten, frexp, hermitian, if, ifelse, index, interleaved, isclose, kronecker, legendre, overlap, pad, permute, planar, polyval, r2c, remap, repmat, reshape, reverse, select, self, shift, sign, slice, sph2cart, toeplitz, unary_operators, updownsample
Notable Implementation Details
- FFTShift operators use
mutablemembers and perfect forwarding to support both const/non-const usage - Range generator embeds runtime values (first, step) in class name strings for uniqueness
- Scalar operations use
MATX_IFNDEF_CUDACC_RTCmacro to exclude host-only code from RTC compilation
Confidence Score: 4/5
- This PR is safe to merge with testing. The changes follow a consistent pattern and enable significant new functionality
- Score reflects the large scale of changes (66 files, 4800+ additions) and complexity of JIT code generation. While the implementation pattern is consistent and well-structured, the PR description states tests will be enabled in a subsequent commit, meaning this code is not yet fully validated against the test suite. The core logic appears sound but warrants thorough testing before production use
- Pay close attention to
include/matx/core/get_grid_dims.hfor the block calculation changes with groups_per_block, andinclude/matx/operators/fftshift.hfor the mutable member pattern
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| include/matx/executors/jit_cuda.h | 4/5 | JIT executor refactored with improved capability queries and launch parameter selection, comprehensive EPT handling for ranks 0-4 |
| include/matx/core/operator_utils.h | 5/5 | Moved to_jit_storage helper, changed std::multiplies to cuda::std::multiplies for RTC compatibility |
| include/matx/core/get_grid_dims.h | 4/5 | Enabled rank 3 and 4 support in get_grid_dims_jit, added assertion for groups_per_block divisibility, fixed block calculation for multiple groups |
| include/matx/operators/cast.h | 4/5 | Added JIT support with ToJITStorage, get_jit_class_name, get_jit_op_str methods and JIT_TYPE_QUERY/JIT_CLASS_QUERY capability handling for CastOp and ComplexCastOp |
| include/matx/operators/scalar_ops.h | 4/5 | Added MATX_IFNDEF_CUDACC_RTC macro to conditionally exclude host-only code from RTC compilation, removed duplicate scalar_internal functions from JIT code generation |
| include/matx/operators/fftshift.h | 3/5 | Added JIT support for all FFTShift/IFFTShift variants with mutable op_ members and non-const operator() overloads, EPT limited to ONE |
Sequence Diagram
sequenceDiagram
participant User
participant CUDAJITExecutor
participant Operator
participant CapabilitySystem
participant JITCompiler as NVRTC Compiler
participant Kernel
User->>CUDAJITExecutor: Exec(op)
CUDAJITExecutor->>Operator: Check SUPPORTS_JIT capability
Operator->>CapabilitySystem: get_capability<SUPPORTS_JIT>()
CapabilitySystem-->>CUDAJITExecutor: true/false
alt JIT Not Supported
CUDAJITExecutor-->>User: Throw error
end
CUDAJITExecutor->>Operator: Query ELEMENTS_PER_THREAD bounds
Operator->>CapabilitySystem: get_capability<ELEMENTS_PER_THREAD>()
CapabilitySystem-->>CUDAJITExecutor: [min_ept, max_ept]
CUDAJITExecutor->>CUDAJITExecutor: find_best_launch_params()
Note over CUDAJITExecutor: Iterates through EPT values<br/>to find optimal config
CUDAJITExecutor->>Operator: Query JIT_CLASS_QUERY
Operator->>CapabilitySystem: get_capability<JIT_CLASS_QUERY>()
Note over Operator: Calls get_jit_op_str()<br/>to generate class definitions
Operator->>Operator: Recursively collect child ops
CapabilitySystem-->>CUDAJITExecutor: Map of class definitions
CUDAJITExecutor->>Operator: Query JIT_TYPE_QUERY
Operator->>CapabilitySystem: get_capability<JIT_TYPE_QUERY>()
Note over Operator: Returns fully qualified<br/>JIT type name
CapabilitySystem-->>CUDAJITExecutor: Type string
CUDAJITExecutor->>Operator: ToJITStorage()
Operator-->>CUDAJITExecutor: JIT_Storage struct
CUDAJITExecutor->>JITCompiler: nvrtc_compile_and_run()
Note over JITCompiler: Compiles operator graph<br/>with optimal EPT and params
JITCompiler->>Kernel: Launch compiled kernel
Kernel-->>User: Results
66 files reviewed, no comments
Tested with standalone unit tests, but will turn on real tests in subsequent commit.
055fb87 to
c35db85
Compare
|
/build |
There was a problem hiding this comment.
Greptile Overview
Greptile Summary
This PR extends JIT compilation support from rank ≤2 operators to all operators with ranks 0-4, excluding transform operators. The implementation adds JIT capabilities to 66 files including generators (range, linspace, chirp, etc.) and operators (cast, reshape, fftshift, etc.).
Key architectural changes:
- Removed rank restrictions in
jit_cuda.h, enabling ranks 3-4 compilation - Updated kernel indexing in
jit_kernel.hfor rank 3-4 to properly use thread/block dimensions - Modified grid dimension calculations to support multi-rank scenarios with groups-per-block
- Refactored CMake to decouple JIT support from MathDx dependency
Operator pattern:
Each JIT-enabled operator now implements:
JIT_Storagestruct for runtime state serializationToJITStorage()method to convert operators to JIT-compatible formget_jit_class_name()for unique type identificationget_jit_op_str()to generate CUDA source code- Capability handlers for
JIT_TYPE_QUERY,JIT_CLASS_QUERY, andDYN_SHM_SIZE
Build system:
Added MATX_EN_JIT CMake option (default OFF) to enable JIT independently of MathDx, with proper NVRTC library linking.
Confidence Score: 4/5
- This PR is safe to merge with minor risk - the changes are well-structured but testing is explicitly deferred to a subsequent commit
- Score reflects: (1) systematic implementation pattern consistently applied across 66 files, (2) proper separation of JIT code with
#ifdef MATX_EN_JITguards ensuring no impact when disabled, (3) architectural improvements like fixing rank 3-4 kernel indexing bugs, but (4) author explicitly states real tests are disabled and will be enabled in a subsequent commit, creating uncertainty about runtime behavior include/matx/executors/jit_kernel.handinclude/matx/core/get_grid_dims.hcontain critical indexing changes for rank 3-4 that should be carefully tested;include/matx/operators/fftshift.huses mutable members which warrants extra scrutiny
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| include/matx/executors/jit_cuda.h | 4/5 | Removed rank <= 2 restriction for JIT compilation, now supports ranks 0-4; refactored parameters for get_grid_dims_jit function signature change |
| include/matx/core/get_grid_dims.h | 4/5 | Added JIT support for ranks 3 and 4, fixed block calculation for multi-group-per-block scenarios, added assertion for dimension divisibility |
| include/matx/executors/jit_kernel.h | 4/5 | Fixed rank 3 and 4 kernel indexing to properly use threadIdx.y/z and blockDim.y/z; rank 4 kernels now use combined index approach (nmy) to handle 2D dimension mapping |
| CMakeLists.txt | 5/5 | Added MATX_EN_JIT build option, refactored NVRTC support to be shared between JIT and MathDx features |
| include/matx/operators/fftshift.h | 4/5 | Added JIT support for FFTShift/IFFTShift operators with EPT=ONE constraint; made operator members mutable and added non-const operator() overloads for FFTShift2D |
Sequence Diagram
sequenceDiagram
participant User
participant JITExecutor as jitCudaExecutor
participant OpTree as Operator Tree
participant CapSys as Capability System
participant GridCalc as Grid Dimensions
participant KernelProv as Kernel Provider
participant NVRTC as NVRTC Compiler
participant GPU
User->>JITExecutor: exec(operator, stream)
JITExecutor->>OpTree: Check JIT support capability
OpTree-->>JITExecutor: SUPPORTS_JIT = true
JITExecutor->>OpTree: Query ELEMENTS_PER_THREAD bounds
OpTree-->>JITExecutor: EPT bounds [min, max]
JITExecutor->>KernelProv: Create kernel_provider lambda
Note over KernelProv: Maps EPT values to kernel functions<br/>based on rank (0-4) and stride
JITExecutor->>CapSys: find_best_launch_params(op, kernel_provider)
CapSys->>GridCalc: Calculate grid dims for each EPT
GridCalc-->>CapSys: blocks, threads, stride for EPT
CapSys->>KernelProv: Get kernel pointer for EPT
KernelProv-->>CapSys: Kernel function pointer
CapSys-->>JITExecutor: best_ept, shm_size, block_size, groups_per_block
JITExecutor->>GridCalc: get_grid_dims_jit(sizes, best_ept, groups_per_block)
GridCalc-->>JITExecutor: final blocks, threads, stride
JITExecutor->>OpTree: Query JIT_CLASS_QUERY
OpTree->>OpTree: Recursively collect operator classes
OpTree-->>JITExecutor: Map of class definitions
JITExecutor->>OpTree: Query JIT_TYPE_QUERY
OpTree-->>JITExecutor: Instantiated type string
JITExecutor->>NVRTC: nvrtc_compile_and_run(code, op, params)
NVRTC->>NVRTC: Generate complete CUDA source
NVRTC->>NVRTC: Compile with nvrtcCompileProgram
NVRTC->>GPU: Load and launch compiled kernel
GPU-->>User: Execution complete
66 files reviewed, no comments
Tested with standalone unit tests, but will turn on real tests in subsequent commit.