-
Notifications
You must be signed in to change notification settings - Fork 235
Add Triton fallback for fused_rope_rms (QKNorm+RoPE) #2227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
66765bd
1a4dec2
e83e6b5
a813eb0
4f2e07d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| name: Auto-Rebase PRs | ||
|
|
||
| on: | ||
| workflow_run: | ||
| workflows: ["Sync Fork", "Sync Upstream"] | ||
| types: [completed] | ||
| workflow_dispatch: | ||
|
|
||
| jobs: | ||
| rebase: | ||
| if: ${{ github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success' }} | ||
| runs-on: ubuntu-latest | ||
| steps: | ||
| - uses: actions/checkout@v4 | ||
| with: | ||
| fetch-depth: 0 | ||
| token: ${{ secrets.GITHUB_TOKEN }} | ||
|
|
||
| - name: Configure git | ||
| run: | | ||
| git config user.name "github-actions[bot]" | ||
| git config user.email "github-actions[bot]@users.noreply.github.com" | ||
|
|
||
| - name: Rebase open PRs | ||
| env: | ||
| GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
| run: | | ||
| DEFAULT_BRANCH=$(gh api "repos/${{ github.repository }}" --jq '.default_branch') | ||
| echo "Default branch: $DEFAULT_BRANCH" | ||
|
|
||
| # Get all open PRs authored by the repo owner | ||
| PRS=$(gh pr list --state open --json number,headRefName,mergeable --jq '.[] | "\(.number) \(.headRefName) \(.mergeable)"') | ||
|
|
||
| if [ -z "$PRS" ]; then | ||
| echo "No open PRs found" | ||
| exit 0 | ||
| fi | ||
|
|
||
| echo "$PRS" | while read -r pr_number branch mergeable; do | ||
| echo "" | ||
| echo "=== PR #${pr_number} (${branch}) mergeable=${mergeable} ===" | ||
|
|
||
| # Fetch and checkout the PR branch | ||
| if ! git fetch origin "$branch" 2>/dev/null; then | ||
| echo " SKIP: branch $branch not found on origin" | ||
| continue | ||
| fi | ||
| git checkout "$branch" | ||
| git reset --hard "origin/$branch" | ||
|
|
||
| # Check if rebase is needed | ||
| git fetch origin "$DEFAULT_BRANCH" | ||
| if git merge-base --is-ancestor "origin/$DEFAULT_BRANCH" HEAD; then | ||
| echo " OK: already up to date with $DEFAULT_BRANCH" | ||
| continue | ||
| fi | ||
|
|
||
| # Attempt rebase | ||
| echo " Rebasing onto origin/$DEFAULT_BRANCH..." | ||
| if git rebase "origin/$DEFAULT_BRANCH" 2>/dev/null; then | ||
| echo " Pushing rebased branch..." | ||
| git push --force-with-lease origin "$branch" | ||
| echo " REBASED: PR #${pr_number} successfully rebased" | ||
| gh pr comment "$pr_number" --body "Auto-rebased onto \`${DEFAULT_BRANCH}\` after nightly upstream sync." 2>/dev/null || true | ||
| else | ||
| git rebase --abort 2>/dev/null || true | ||
| echo " CONFLICT: PR #${pr_number} has merge conflicts" | ||
| # Label the PR for manual attention | ||
| gh pr edit "$pr_number" --add-label "needs-rebase" 2>/dev/null || true | ||
| gh pr comment "$pr_number" --body "Auto-rebase failed due to merge conflicts with \`${DEFAULT_BRANCH}\`. Manual rebase needed." 2>/dev/null || true | ||
| fi | ||
| done | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| name: Sync Fork | ||
|
|
||
| on: | ||
|
Comment on lines
+1
to
+3
|
||
| schedule: | ||
| - cron: '0 6 * * *' # 6am UTC daily (before dashboard collection at 8am) | ||
| workflow_dispatch: | ||
|
|
||
| jobs: | ||
| sync: | ||
| runs-on: ubuntu-latest | ||
| steps: | ||
| - name: Sync fork default branch with upstream | ||
| run: | | ||
| # Try fast-forward sync first; fall back to force sync if diverged. | ||
| # Safe because feature work lives on branches, not the default branch. | ||
| if gh repo sync "${{ github.repository }}" 2>/dev/null; then | ||
| echo "Synced successfully (fast-forward)" | ||
| else | ||
| echo "Diverging commits detected — force syncing to match upstream" | ||
| gh repo sync "${{ github.repository }}" --force | ||
| echo "Force synced successfully" | ||
| fi | ||
| env: | ||
| GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1293,21 +1293,20 @@ def forward( | |
| else: | ||
| return q_out, None, None | ||
| else: | ||
| raise NotImplementedError("fused_rope_rms not supported yet") | ||
| # fused_rope_rms( | ||
| # qkv, | ||
| # q_weight, | ||
| # k_weight, | ||
| # self.cos_sin_cache, | ||
| # positions, | ||
| # num_tokens, | ||
| # num_heads_q, | ||
| # num_heads_k, | ||
| # num_heads_v, | ||
| # self.head_size, | ||
| # self.is_neox_style, | ||
| # eps, | ||
| # ) | ||
| fused_rope_rms( | ||
| qkv, | ||
| q_weight, | ||
| k_weight, | ||
| self.cos_sin_cache, | ||
| positions, | ||
| num_tokens, | ||
| num_heads_q, | ||
| num_heads_k, | ||
| num_heads_v, | ||
| self.head_size, | ||
| self.is_neox_style, | ||
| eps, | ||
| ) | ||
| q_size = num_heads_q * self.head_size | ||
| k_size = num_heads_k * self.head_size | ||
| v_size = num_heads_v * self.head_size | ||
|
|
@@ -1318,6 +1317,67 @@ def forward( | |
| return q, k, v | ||
|
|
||
|
|
||
| def fused_rope_rms( | ||
|
||
| qkv, | ||
| q_weight, | ||
| k_weight, | ||
| cos_sin_cache, | ||
| positions, | ||
| num_tokens, | ||
| num_heads_q, | ||
| num_heads_k, | ||
| num_heads_v, | ||
| head_size, | ||
| is_neox_style, | ||
| eps, | ||
| ): | ||
| """Fused QK-RMSNorm + RoPE on packed QKV tensor (in-place). | ||
| Triton fallback for the HIP fused kernel. | ||
| """ | ||
|
Comment on lines
+1334
to
+1336
|
||
| from aiter.ops.triton.normalization.rmsnorm import rmsnorm_forward_inference | ||
| from aiter.ops.triton.rope.rope import ( | ||
| rope_cached_thd_positions_2c_fwd_inplace, | ||
| ) | ||
|
|
||
| q_size = num_heads_q * head_size | ||
| k_size = num_heads_k * head_size | ||
| v_size = num_heads_v * head_size | ||
|
|
||
| qkv_2d = qkv.view(num_tokens, q_size + k_size + v_size) | ||
|
||
| q, k, _v = qkv_2d.split([q_size, k_size, v_size], dim=-1) | ||
|
|
||
| # Per-head RMSNorm: [T, H*D] -> [T*H, D] so rmsnorm operates per-head | ||
| q_normed = rmsnorm_forward_inference( | ||
| q.reshape(num_tokens * num_heads_q, head_size), q_weight, eps | ||
| ) | ||
| q.copy_(q_normed.view(num_tokens, q_size)) | ||
|
|
||
| k_normed = rmsnorm_forward_inference( | ||
| k.reshape(num_tokens * num_heads_k, head_size), k_weight, eps | ||
| ) | ||
| k.copy_(k_normed.view(num_tokens, k_size)) | ||
|
|
||
| # RoPE in-place | ||
| q_rope = q.view(num_tokens, num_heads_q, head_size) | ||
| k_rope = k.view(num_tokens, num_heads_k, head_size) | ||
|
|
||
| half = cos_sin_cache.shape[-1] // 2 | ||
| cos = cos_sin_cache[:, :half] | ||
| sin = cos_sin_cache[:, half:] | ||
| rotate_style = 0 if is_neox_style else 1 | ||
|
|
||
| rope_cached_thd_positions_2c_fwd_inplace( | ||
| q_rope, | ||
| k_rope, | ||
| cos, | ||
| sin, | ||
| positions, | ||
| rotate_style, | ||
| reuse_freqs_front_part=True, | ||
| nope_first=False, | ||
| ) | ||
|
|
||
|
|
||
| class MRotaryEmbeddingQKNormFused(RotaryEmbeddingFusedQKNorm): | ||
| """Rotary Embedding with Multimodal Sections fused with QKNorm""" | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This workflow force-pushes branches and also comments/labels PRs via
gh pr comment/gh pr edit, but it does not declare requiredpermissions. Without explicit permissions, pushes and PR edits commonly fail with the defaultGITHUB_TOKEN. Add workflow/job permissions such ascontents: write(push),pull-requests: write(comment), andissues: write(labels).