Skip to content

Added JIT capabilities into all operators except transform operators.#1085

Merged
cliffburdick merged 1 commit intomainfrom
all_operator_jit
Nov 5, 2025
Merged

Added JIT capabilities into all operators except transform operators.#1085
cliffburdick merged 1 commit intomainfrom
all_operator_jit

Conversation

@cliffburdick
Copy link
Collaborator

Tested with standalone unit tests, but will turn on real tests in subsequent commit.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Nov 5, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick
Copy link
Collaborator Author

/build

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR adds comprehensive JIT (Just-In-Time) compilation support to all MatX operators and generators except transform operators. The implementation follows a consistent pattern across all affected files.

Key Changes

  • Core Infrastructure: Enhanced capability system queries for JIT_TYPE_QUERY and JIT_CLASS_QUERY to enable runtime code generation
  • Operator Pattern: Each operator now includes:
    • JIT_Storage struct with serializable members
    • ToJITStorage() method for converting runtime state to JIT-compatible form
    • get_jit_class_name() and get_jit_op_str() for code generation
    • Enhanced get_capability() with JIT-specific capability handling
  • Executor: CUDAJITExecutor refactored with improved launch parameter selection using capability queries
  • Grid Dimensions: Extended rank 3 and 4 support in get_grid_dims_jit(), added assertion for groups_per_block divisibility
  • RTC Compatibility: Changed std:: to cuda::std:: types throughout for NVRTC compiler compatibility

Operators/Generators Modified (66 files total)

  • All generators: alternate, bartlett, blackman, chirp, diag, fftfreq, flattop, hamming, hanning, linspace, logspace, meshgrid, range
  • All operators: at, cart2sph, cast, clone, collapse, comma, constval, cross, diag, fftshift, flatten, frexp, hermitian, if, ifelse, index, interleaved, isclose, kronecker, legendre, overlap, pad, permute, planar, polyval, r2c, remap, repmat, reshape, reverse, select, self, shift, sign, slice, sph2cart, toeplitz, unary_operators, updownsample

Notable Implementation Details

  • FFTShift operators use mutable members and perfect forwarding to support both const/non-const usage
  • Range generator embeds runtime values (first, step) in class name strings for uniqueness
  • Scalar operations use MATX_IFNDEF_CUDACC_RTC macro to exclude host-only code from RTC compilation

Confidence Score: 4/5

  • This PR is safe to merge with testing. The changes follow a consistent pattern and enable significant new functionality
  • Score reflects the large scale of changes (66 files, 4800+ additions) and complexity of JIT code generation. While the implementation pattern is consistent and well-structured, the PR description states tests will be enabled in a subsequent commit, meaning this code is not yet fully validated against the test suite. The core logic appears sound but warrants thorough testing before production use
  • Pay close attention to include/matx/core/get_grid_dims.h for the block calculation changes with groups_per_block, and include/matx/operators/fftshift.h for the mutable member pattern

Important Files Changed

File Analysis

Filename Score Overview
include/matx/executors/jit_cuda.h 4/5 JIT executor refactored with improved capability queries and launch parameter selection, comprehensive EPT handling for ranks 0-4
include/matx/core/operator_utils.h 5/5 Moved to_jit_storage helper, changed std::multiplies to cuda::std::multiplies for RTC compatibility
include/matx/core/get_grid_dims.h 4/5 Enabled rank 3 and 4 support in get_grid_dims_jit, added assertion for groups_per_block divisibility, fixed block calculation for multiple groups
include/matx/operators/cast.h 4/5 Added JIT support with ToJITStorage, get_jit_class_name, get_jit_op_str methods and JIT_TYPE_QUERY/JIT_CLASS_QUERY capability handling for CastOp and ComplexCastOp
include/matx/operators/scalar_ops.h 4/5 Added MATX_IFNDEF_CUDACC_RTC macro to conditionally exclude host-only code from RTC compilation, removed duplicate scalar_internal functions from JIT code generation
include/matx/operators/fftshift.h 3/5 Added JIT support for all FFTShift/IFFTShift variants with mutable op_ members and non-const operator() overloads, EPT limited to ONE

Sequence Diagram

sequenceDiagram
    participant User
    participant CUDAJITExecutor
    participant Operator
    participant CapabilitySystem
    participant JITCompiler as NVRTC Compiler
    participant Kernel

    User->>CUDAJITExecutor: Exec(op)
    CUDAJITExecutor->>Operator: Check SUPPORTS_JIT capability
    Operator->>CapabilitySystem: get_capability<SUPPORTS_JIT>()
    CapabilitySystem-->>CUDAJITExecutor: true/false
    
    alt JIT Not Supported
        CUDAJITExecutor-->>User: Throw error
    end

    CUDAJITExecutor->>Operator: Query ELEMENTS_PER_THREAD bounds
    Operator->>CapabilitySystem: get_capability<ELEMENTS_PER_THREAD>()
    CapabilitySystem-->>CUDAJITExecutor: [min_ept, max_ept]

    CUDAJITExecutor->>CUDAJITExecutor: find_best_launch_params()
    Note over CUDAJITExecutor: Iterates through EPT values<br/>to find optimal config

    CUDAJITExecutor->>Operator: Query JIT_CLASS_QUERY
    Operator->>CapabilitySystem: get_capability<JIT_CLASS_QUERY>()
    Note over Operator: Calls get_jit_op_str()<br/>to generate class definitions
    Operator->>Operator: Recursively collect child ops
    CapabilitySystem-->>CUDAJITExecutor: Map of class definitions

    CUDAJITExecutor->>Operator: Query JIT_TYPE_QUERY
    Operator->>CapabilitySystem: get_capability<JIT_TYPE_QUERY>()
    Note over Operator: Returns fully qualified<br/>JIT type name
    CapabilitySystem-->>CUDAJITExecutor: Type string

    CUDAJITExecutor->>Operator: ToJITStorage()
    Operator-->>CUDAJITExecutor: JIT_Storage struct

    CUDAJITExecutor->>JITCompiler: nvrtc_compile_and_run()
    Note over JITCompiler: Compiles operator graph<br/>with optimal EPT and params
    JITCompiler->>Kernel: Launch compiled kernel
    Kernel-->>User: Results
Loading

66 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Tested with standalone unit tests, but will turn on real tests in
subsequent commit.
@cliffburdick
Copy link
Collaborator Author

/build

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR extends JIT compilation support from rank ≤2 operators to all operators with ranks 0-4, excluding transform operators. The implementation adds JIT capabilities to 66 files including generators (range, linspace, chirp, etc.) and operators (cast, reshape, fftshift, etc.).

Key architectural changes:

  • Removed rank restrictions in jit_cuda.h, enabling ranks 3-4 compilation
  • Updated kernel indexing in jit_kernel.h for rank 3-4 to properly use thread/block dimensions
  • Modified grid dimension calculations to support multi-rank scenarios with groups-per-block
  • Refactored CMake to decouple JIT support from MathDx dependency

Operator pattern:
Each JIT-enabled operator now implements:

  • JIT_Storage struct for runtime state serialization
  • ToJITStorage() method to convert operators to JIT-compatible form
  • get_jit_class_name() for unique type identification
  • get_jit_op_str() to generate CUDA source code
  • Capability handlers for JIT_TYPE_QUERY, JIT_CLASS_QUERY, and DYN_SHM_SIZE

Build system:
Added MATX_EN_JIT CMake option (default OFF) to enable JIT independently of MathDx, with proper NVRTC library linking.

Confidence Score: 4/5

  • This PR is safe to merge with minor risk - the changes are well-structured but testing is explicitly deferred to a subsequent commit
  • Score reflects: (1) systematic implementation pattern consistently applied across 66 files, (2) proper separation of JIT code with #ifdef MATX_EN_JIT guards ensuring no impact when disabled, (3) architectural improvements like fixing rank 3-4 kernel indexing bugs, but (4) author explicitly states real tests are disabled and will be enabled in a subsequent commit, creating uncertainty about runtime behavior
  • include/matx/executors/jit_kernel.h and include/matx/core/get_grid_dims.h contain critical indexing changes for rank 3-4 that should be carefully tested; include/matx/operators/fftshift.h uses mutable members which warrants extra scrutiny

Important Files Changed

File Analysis

Filename Score Overview
include/matx/executors/jit_cuda.h 4/5 Removed rank <= 2 restriction for JIT compilation, now supports ranks 0-4; refactored parameters for get_grid_dims_jit function signature change
include/matx/core/get_grid_dims.h 4/5 Added JIT support for ranks 3 and 4, fixed block calculation for multi-group-per-block scenarios, added assertion for dimension divisibility
include/matx/executors/jit_kernel.h 4/5 Fixed rank 3 and 4 kernel indexing to properly use threadIdx.y/z and blockDim.y/z; rank 4 kernels now use combined index approach (nmy) to handle 2D dimension mapping
CMakeLists.txt 5/5 Added MATX_EN_JIT build option, refactored NVRTC support to be shared between JIT and MathDx features
include/matx/operators/fftshift.h 4/5 Added JIT support for FFTShift/IFFTShift operators with EPT=ONE constraint; made operator members mutable and added non-const operator() overloads for FFTShift2D

Sequence Diagram

sequenceDiagram
    participant User
    participant JITExecutor as jitCudaExecutor
    participant OpTree as Operator Tree
    participant CapSys as Capability System
    participant GridCalc as Grid Dimensions
    participant KernelProv as Kernel Provider
    participant NVRTC as NVRTC Compiler
    participant GPU

    User->>JITExecutor: exec(operator, stream)
    JITExecutor->>OpTree: Check JIT support capability
    OpTree-->>JITExecutor: SUPPORTS_JIT = true
    
    JITExecutor->>OpTree: Query ELEMENTS_PER_THREAD bounds
    OpTree-->>JITExecutor: EPT bounds [min, max]
    
    JITExecutor->>KernelProv: Create kernel_provider lambda
    Note over KernelProv: Maps EPT values to kernel functions<br/>based on rank (0-4) and stride
    
    JITExecutor->>CapSys: find_best_launch_params(op, kernel_provider)
    CapSys->>GridCalc: Calculate grid dims for each EPT
    GridCalc-->>CapSys: blocks, threads, stride for EPT
    CapSys->>KernelProv: Get kernel pointer for EPT
    KernelProv-->>CapSys: Kernel function pointer
    CapSys-->>JITExecutor: best_ept, shm_size, block_size, groups_per_block
    
    JITExecutor->>GridCalc: get_grid_dims_jit(sizes, best_ept, groups_per_block)
    GridCalc-->>JITExecutor: final blocks, threads, stride
    
    JITExecutor->>OpTree: Query JIT_CLASS_QUERY
    OpTree->>OpTree: Recursively collect operator classes
    OpTree-->>JITExecutor: Map of class definitions
    
    JITExecutor->>OpTree: Query JIT_TYPE_QUERY
    OpTree-->>JITExecutor: Instantiated type string
    
    JITExecutor->>NVRTC: nvrtc_compile_and_run(code, op, params)
    NVRTC->>NVRTC: Generate complete CUDA source
    NVRTC->>NVRTC: Compile with nvrtcCompileProgram
    NVRTC->>GPU: Load and launch compiled kernel
    GPU-->>User: Execution complete
Loading

66 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@cliffburdick cliffburdick merged commit e78cd18 into main Nov 5, 2025
1 check passed
@cliffburdick cliffburdick deleted the all_operator_jit branch November 5, 2025 17:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant