Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions .github/workflows/flash_attention_integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ env:
# TODO: Switch to Dao-AILab/flash-attention main
FA_BRANCH: micmelesse/aiter_migration
FA_REPOSITORY_URL: https://github.com/ROCm/flash-attention.git
GPU_ARCH: gfx950
BASE_IMAGE: rocm/pytorch:latest@sha256:683765a52c61341e1674fe730ab3be861a444a45a36c0a8caae7653a08a0e208
AITER_SUBMODULE_PATH: third_party/aiter

Expand Down Expand Up @@ -90,9 +89,17 @@ jobs:
# =============================================================================
flash_attention_triton:
if: ${{ needs.prechecks.outputs.run_triton == 'true' }}
name: Flash Attention - Triton (1 GPU)
name: Flash Attention - Triton / ${{ matrix.label }} (1 GPU)
needs: [check-signal, prechecks]
runs-on: linux-aiter-mi355-1
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: linux-aiter-mi355-1
label: MI355
- runner: aiter-gfx1100
label: RDNA3

steps:
- name: Checkout aiter repo
Expand Down Expand Up @@ -187,14 +194,14 @@ jobs:
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
python benchmarks/benchmark_flash_attention.py
" |& tee benchmark_triton.log
" |& tee benchmark_triton_${{ matrix.label }}.log

- name: Upload benchmark results
if: success()
uses: actions/upload-artifact@v4
with:
name: flash-attention-triton-benchmark
path: benchmark_triton.log
name: flash-attention-triton-benchmark-${{ matrix.label }}
path: benchmark_triton_${{ matrix.label }}.log

- name: Clean Up
if: always()
Expand Down