coremltools/converters/mil/mil/passes/defs/transformer.py (48 additions, 2 deletions)

```diff
@@ -12,6 +12,7 @@
 from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
 from coremltools.converters.mil.mil import Builder as mb
 from coremltools.converters.mil.mil import types
+from coremltools.converters.mil.mil.types.symbolic import is_symbolic
```
**Collaborator:** I don't think you need this line.

```diff
 from coremltools import _logger as logger

@@ -30,7 +31,17 @@ class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
     Defines the size of the chunks of Q being processed in SDPA (chunk_size = seq_length / seq_length_divider)
     """

+    # Default threshold for dynamic-shape models. Dynamic shapes use runtime allocation
```
**Collaborator:** The added comments in this file are far too long. They need to be much more concise.

```diff
+    # which is more flexible, so we use a higher threshold to avoid unnecessary slicing
+    # overhead. This preserves the original behavior for dynamic-shape models.
     _DEFAULT_MIN_SEQ_LENGTH: ClassVar[int] = 1280
+
+    # Lower threshold for static shapes. Static shapes enable compile-time buffer allocation
+    # by Metal, which can lead to excessive memory usage on mobile devices if attention
+    # intermediates (QK^T matrices) are materialized eagerly. Slicing breaks computation
+    # into smaller chunks that can reuse buffers, reducing peak memory allocation and
+    # preventing OOM issues.
+    _MIN_SEQ_LENGTH_STATIC: ClassVar[int] = 64
     _DEFAULT_SEQ_LENGTH_DIVIDER: ClassVar[int] = 16

     _min_seq_length: int

@@ -106,9 +117,44 @@ def _replace_scaled_dot_product_attention(self, op):
         q_size = len(q.shape)
         q_seq_length = q.shape[-2]
-        if q_seq_length < self._min_seq_length:
+
+        # Determine if the sequence length is statically known (compile-time constant).
+        # Static shapes enable compile-time buffer allocation by Metal, which can lead to
+        # excessive memory usage on mobile devices. For attention ops, Metal eagerly
+        # allocates large intermediate buffers (QK^T matrices) based on the full sequence
+        # length, potentially causing OOM. Slicing prevents this by breaking computation
+        # into smaller chunks that can reuse buffers, reducing peak memory allocation.
+        #
+        # Dynamic shapes use runtime allocation which is more flexible and doesn't suffer
+        # from the same eager allocation issues. We preserve the original behavior for
+        # dynamic shapes (threshold of 1280) to avoid unnecessary slicing overhead.
+        is_static_seq_length = not is_symbolic(q_seq_length)
+
+        if is_static_seq_length:
+            # Static shape: use lower threshold (_MIN_SEQ_LENGTH_STATIC = 64) to prevent
+            # eager buffer allocation issues. With static shapes, Metal pre-allocates
+            # buffers based on the full sequence length, which can be problematic for
+            # attention ops that create large QK^T intermediates. Slicing reduces peak
+            # memory by processing in chunks.
+            seq_length_value = int(q_seq_length)
+            min_seq_length_threshold = self._MIN_SEQ_LENGTH_STATIC
+
+            if seq_length_value < min_seq_length_threshold:
+                logger.debug(
+                    f"skipping SDPA op, Q seq_length is {seq_length_value} (static, "
+                    f"minimum seq length needed: {min_seq_length_threshold})"
+                )
+                return
+        else:
+            # Dynamic shape: preserve original behavior with default threshold (1280).
+            # We cannot compare symbolic values at compile time, so we skip slicing
+            # for dynamic shapes. This maintains the original behavior where only
+            # very long sequences (>= 1280) would be sliced. Runtime allocation for
+            # dynamic shapes is more flexible and doesn't suffer from the same eager
+            # allocation issues as static shapes, so the higher threshold is appropriate.
             logger.debug(
-                f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length})"
+                f"skipping SDPA op, Q seq_length is dynamic (symbolic), "
```
**Collaborator:** This shouldn't be an f-string since there is no variable being used.

**Collaborator:** Looks like this is also an issue in several other places of this PR.
f"preserving original behavior (threshold: {self._min_seq_length})"
)
return

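To make the memory argument in the PR's comments concrete, here is a minimal NumPy sketch of Q-sliced scaled dot-product attention. It is an illustration only, not the pass's actual MIL implementation: the function names are hypothetical, and the constants mirror the class attributes in the diff above.

```python
import numpy as np

MIN_SEQ_LENGTH_STATIC = 64  # mirrors _MIN_SEQ_LENGTH_STATIC above
SEQ_LENGTH_DIVIDER = 16     # mirrors _DEFAULT_SEQ_LENGTH_DIVIDER above

def sdpa_full(q, k, v):
    """Plain SDPA: materializes the full (seq, seq) score matrix at once."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

def sdpa_sliced_q(q, k, v):
    """Q-sliced SDPA: each chunk materializes only a (seq/divider, seq)
    score matrix, cutting peak intermediate memory by ~SEQ_LENGTH_DIVIDER."""
    seq_len = q.shape[-2]
    if seq_len < MIN_SEQ_LENGTH_STATIC:
        return sdpa_full(q, k, v)  # below the threshold, slicing is not worth it
    chunk = max(1, seq_len // SEQ_LENGTH_DIVIDER)
    outs = [sdpa_full(q[..., i:i + chunk, :], k, v)
            for i in range(0, seq_len, chunk)]
    return np.concatenate(outs, axis=-2)

# Slicing along Q's sequence axis is exact: each output row depends only on
# its own row of attention scores, so the chunked result matches the full one.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 128, 32)) for _ in range(3))
assert np.allclose(sdpa_full(q, k, v), sdpa_sliced_q(q, k, v))
```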
coremltools/converters/mil/mil/passes/pass_pipeline.py (3 additions)

```diff
@@ -50,6 +50,9 @@
     "common::loop_invariant_elimination",
     "common::remove_symbolic_reshape",
     "common::noop_elimination",
+    # Apply attention slicing early to reduce memory allocation for static sequence lengths.
+    # This pass replaces scaled_dot_product_attention with a memory-efficient sliced implementation.
```
**Collaborator:** Remove this line of the comment. It doesn't really add much and can easily become outdated/inaccurate.
"common::scaled_dot_product_attention_sliced_q",
"common::fuse_matmul_weight_bias",
"common::fuse_linear_bias",
"common::fuse_gelu_tanh_approximation",
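Since the thresholds discussed above are per-pass options, they can in principle be tuned at conversion time through the pass pipeline. A hedged sketch, assuming a traced PyTorch model `traced_model` and assuming the pass exposes `min_seq_length` and `seq_length_divider` options (option names inferred from the class attributes in the diff; verify against your coremltools version):

```python
import coremltools as ct

pipeline = ct.PassPipeline.DEFAULT
# Option keys are inferred from the pass's class attributes, not confirmed
# against the released API; check the docs for your coremltools version.
pipeline.set_options(
    "common::scaled_dot_product_attention_sliced_q",
    {"min_seq_length": 256, "seq_length_divider": 8},
)
mlmodel = ct.convert(
    traced_model,  # hypothetical: a torch.jit.trace'd model
    inputs=[ct.TensorType(name="x", shape=(1, 512, 768))],  # example shape
    pass_pipeline=pipeline,
)
```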