Add Triton fallback for fused_rope_rms (QKNorm+RoPE)#2227
Add Triton fallback for fused_rope_rms (QKNorm+RoPE)#2227
Conversation
Implement fused_rope_rms as a standalone function using existing Triton RMSNorm and RoPE kernels. This unblocks Qwen3-MoE on CK-free builds where the HIP fused kernel is unavailable. The fallback applies per-head RMSNorm via reshape [T,H*D]->[T*H,D] then in-place RoPE, all operating through QKV tensor views.
There was a problem hiding this comment.
Pull request overview
Adds a Triton-based fallback for the fused QKNorm+RoPE path so CK-free/HIP-kernel-missing builds can still run the fused_rope_rms flow.
Changes:
- Enables the previously disabled
fused_rope_rmscall path inRotaryEmbeddingFusedQKNorm.forward(). - Introduces a new Triton fallback implementation combining per-head RMSNorm + in-place RoPE.
- Exports
fused_rope_rmsat the package level and adds GitHub workflows for fork syncing and auto-rebasing PRs.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| aiter/rotary_embedding.py | Implements Triton fallback fused_rope_rms() and wires it into the forward path. |
| aiter/init.py | Re-exports fused_rope_rms from the top-level package. |
| .github/workflows/sync-fork.yml | Adds a scheduled/manual workflow to sync a fork’s default branch. |
| .github/workflows/auto-rebase-prs.yml | Adds a workflow to rebase open PR branches after sync workflows complete. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| k_size = num_heads_k * head_size | ||
| v_size = num_heads_v * head_size | ||
|
|
||
| qkv_2d = qkv.view(num_tokens, q_size + k_size + v_size) |
There was a problem hiding this comment.
Tensor.view() will throw at runtime if qkv is non-contiguous (which can happen after slicing/transpose or some fused ops). Since this is a fallback path meant to be robust, prefer reshape(...) here (or make qkv contiguous before viewing) to avoid hard failures on valid inputs.
| return q, k, v | ||
|
|
||
|
|
||
| def fused_rope_rms( |
There was a problem hiding this comment.
This function is exported publicly (via aiter/__init__.py) but its API contract is currently ambiguous: it mutates qkv in-place and returns None, and callers may expect outputs similar to other embedding helpers. Please document the in-place behavior and expected tensor shapes/dtypes in the docstring (and/or consider returning (q, k) or (q, k, v) views explicitly) so external consumers don’t misuse it.
| """Fused QK-RMSNorm + RoPE on packed QKV tensor (in-place). | ||
| Triton fallback for the HIP fused kernel. | ||
| """ |
There was a problem hiding this comment.
This function is exported publicly (via aiter/__init__.py) but its API contract is currently ambiguous: it mutates qkv in-place and returns None, and callers may expect outputs similar to other embedding helpers. Please document the in-place behavior and expected tensor shapes/dtypes in the docstring (and/or consider returning (q, k) or (q, k, v) views explicitly) so external consumers don’t misuse it.
| name: Sync Fork | ||
|
|
||
| on: |
There was a problem hiding this comment.
This workflow relies on GITHUB_TOKEN to sync and potentially update the default branch, but it does not declare permissions. On many repos/orgs the default token permissions are read-only, causing gh repo sync to fail. Add an explicit permissions: block (at least contents: write) to ensure the sync can push updates.
| name: Auto-Rebase PRs | ||
|
|
||
| on: |
There was a problem hiding this comment.
This workflow force-pushes branches and also comments/labels PRs via gh pr comment / gh pr edit, but it does not declare required permissions. Without explicit permissions, pushes and PR edits commonly fail with the default GITHUB_TOKEN. Add workflow/job permissions such as contents: write (push), pull-requests: write (comment), and issues: write (labels).
gyohuangxin
left a comment
There was a problem hiding this comment.
Hey, are sync-fork.yml and auto-rebase-prs.yml meant for your personal fork? I don't think they should be merged into the upstream repo. Could you drop these two files from the PR?
Summary
fused_rope_rms()as a pure-Triton fallback for the HIP fused QKNorm+RoPE kernelRotaryEmbeddingFusedQKNorm(previously raisedNotImplementedError)fused_rope_rmsfromaiter/__init__.pyMotivation
In CK-free builds (e.g., clean Docker images for faster CI/deployment), the HIP fused kernel
fused_qk_norm_rope_cache_quant_shuffleis unavailable. This Triton implementation enables thefused_rope_rmspath to work without any HIP/CK dependencies, using existing Triton kernels:rmsnorm_forward_inferencefor per-head RMSNormrope_cached_thd_positions_2c_fwd_inplacefor RoPEChanges
aiter/rotary_embedding.py: Addfused_rope_rms()function (67 lines), uncomment call siteaiter/__init__.py: Exportfused_rope_rmsTest plan