Fix excessive memory allocation for static-shape attention ops #2636

base: main
@@ -12,6 +12,7 @@
 from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
 from coremltools.converters.mil.mil import Builder as mb
 from coremltools.converters.mil.mil import types
+from coremltools.converters.mil.mil.types.symbolic import is_symbolic

 from coremltools import _logger as logger
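(Aside, not part of the diff: a minimal sketch of how is_symbolic drives the new branching. The pick_threshold helper below is hypothetical; it only condenses the logic the pass adds in the next hunks.)

from coremltools.converters.mil.mil.types.symbolic import is_symbolic

def pick_threshold(q_seq_length, static_threshold=64, dynamic_threshold=1280):
    # Hypothetical helper, not in the PR: symbolic (runtime-determined)
    # dims keep the original higher threshold; statically known dims get
    # the new lower one, mirroring the pass's branch below.
    return dynamic_threshold if is_symbolic(q_seq_length) else static_threshold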
@@ -30,7 +31,17 @@ class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
     Defines the size of the chunks of Q being processed in SDPA (chunk_size = seq_length / seq_length_divider)
     """

+    # Default threshold for dynamic-shape models. Dynamic shapes use runtime allocation,
+    # which is more flexible, so we use a higher threshold to avoid unnecessary slicing
+    # overhead. This preserves the original behavior for dynamic-shape models.
     _DEFAULT_MIN_SEQ_LENGTH: ClassVar[int] = 1280
+
+    # Lower threshold for static shapes. Static shapes enable compile-time buffer allocation
+    # by Metal, which can lead to excessive memory usage on mobile devices if attention
+    # intermediates (QK^T matrices) are materialized eagerly. Slicing breaks computation
+    # into smaller chunks that can reuse buffers, reducing peak memory allocation and
+    # preventing OOM issues.
+    _MIN_SEQ_LENGTH_STATIC: ClassVar[int] = 64
     _DEFAULT_SEQ_LENGTH_DIVIDER: ClassVar[int] = 16

     _min_seq_length: int

Collaborator
The added comments in this file are far too long. They need to be much more concise.
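For a concrete sense of scale (my arithmetic, not from the PR): with _DEFAULT_SEQ_LENGTH_DIVIDER = 16, a static sequence length of 4096 is processed in chunks of 4096 / 16 = 256, so the largest attention score matrix per step shrinks from 4096 x 4096 to 256 x 4096.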
@@ -106,9 +117,44 @@ def _replace_scaled_dot_product_attention(self, op):
         q_size = len(q.shape)
         q_seq_length = q.shape[-2]
-        if q_seq_length < self._min_seq_length:
+
+        # Determine if the sequence length is statically known (compile-time constant).
+        # Static shapes enable compile-time buffer allocation by Metal, which can lead to
+        # excessive memory usage on mobile devices. For attention ops, Metal eagerly
+        # allocates large intermediate buffers (QK^T matrices) based on the full sequence
+        # length, potentially causing OOM. Slicing prevents this by breaking computation
+        # into smaller chunks that can reuse buffers, reducing peak memory allocation.
+        #
+        # Dynamic shapes use runtime allocation, which is more flexible and doesn't suffer
+        # from the same eager allocation issues. We preserve the original behavior for
+        # dynamic shapes (threshold of 1280) to avoid unnecessary slicing overhead.
+        is_static_seq_length = not is_symbolic(q_seq_length)
+
+        if is_static_seq_length:
+            # Static shape: use lower threshold (_MIN_SEQ_LENGTH_STATIC = 64) to prevent
+            # eager buffer allocation issues. With static shapes, Metal pre-allocates
+            # buffers based on the full sequence length, which can be problematic for
+            # attention ops that create large QK^T intermediates. Slicing reduces peak
+            # memory by processing in chunks.
+            seq_length_value = int(q_seq_length)
+            min_seq_length_threshold = self._MIN_SEQ_LENGTH_STATIC
+
+            if seq_length_value < min_seq_length_threshold:
+                logger.debug(
+                    f"skipping SDPA op, Q seq_length is {seq_length_value} (static, "
+                    f"minimum seq length needed: {min_seq_length_threshold})"
+                )
+                return
+        else:
+            # Dynamic shape: preserve original behavior with default threshold (1280).
+            # We cannot compare symbolic values at compile time, so we skip slicing
+            # for dynamic shapes. This maintains the original behavior where only
+            # very long sequences (>= 1280) would be sliced. Runtime allocation for
+            # dynamic shapes is more flexible and doesn't suffer from the same eager
+            # allocation issues as static shapes, so the higher threshold is appropriate.
             logger.debug(
-                f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length}"
+                f"skipping SDPA op, Q seq_length is dynamic (symbolic), "
+                f"preserving original behavior (threshold: {self._min_seq_length})"
             )
             return

Collaborator
This shouldn't be an f-string since there is no variable being used.

Collaborator
Looks like this is also an issue in several other places of this PR.
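To see why chunking Q reduces peak memory, here is a rough NumPy sketch of sliced SDPA under my own assumptions (it is not the MIL implementation this pass emits): each chunk materializes only a (chunk, seq_len) score matrix instead of the full (seq_len, seq_len) one.

import numpy as np

def sliced_sdpa(q, k, v, seq_length_divider=16):
    # Rough sketch, not the pass's actual lowering: slice Q along the
    # sequence axis and run attention per chunk, so intermediates stay
    # chunk-sized instead of one seq_len x seq_len score matrix.
    seq_len, d = q.shape[-2], q.shape[-1]
    chunk = max(1, seq_len // seq_length_divider)
    outputs = []
    for start in range(0, seq_len, chunk):
        q_chunk = q[..., start:start + chunk, :]
        scores = q_chunk @ np.swapaxes(k, -1, -2) / np.sqrt(d)
        # Numerically stable softmax over the key axis.
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        outputs.append(weights @ v)
    return np.concatenate(outputs, axis=-2)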
@@ -50,6 +50,9 @@
     "common::loop_invariant_elimination",
     "common::remove_symbolic_reshape",
     "common::noop_elimination",
+    # Apply attention slicing early to reduce memory allocation for static sequence lengths.
+    # This pass replaces scaled_dot_product_attention with a memory-efficient sliced implementation.
+    "common::scaled_dot_product_attention_sliced_q",
     "common::fuse_matmul_weight_bias",
     "common::fuse_linear_bias",
     "common::fuse_gelu_tanh_approximation",

Collaborator
Remove this line of the comment. It doesn't really add much and can easily become outdated/inaccurate.
Collaborator
I don't think you need this line.
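If the option names in this diff survive review, tuning the pass from the public API might look like the sketch below. The seq_length_divider option name is inferred from the class attributes above, and the PassPipeline.set_options call follows the documented coremltools pattern; treat both as assumptions rather than the merged API.

import coremltools as ct

# Sketch under assumptions: the pass option name is taken from this diff
# and may differ once the PR is merged.
pipeline = ct.PassPipeline.DEFAULT
pipeline.set_options(
    "common::scaled_dot_product_attention_sliced_q",
    {"seq_length_divider": 16},
)
# mlmodel = ct.convert(traced_model, pass_pipeline=pipeline)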