
fix: only instantiate CrossAttentionBlock when with_cross_attention=True #8848

Open

chhayankjain wants to merge 1 commit into Project-MONAI:dev from chhayankjain:8845-fix-transformer-cross-attention-init

Conversation

@chhayankjain
Contributor

Fixes #8845.

Description

TransformerBlock previously instantiated norm_cross_attn and cross_attn unconditionally in __init__, even when with_cross_attention=False. These unused modules registered dead parameters in model.parameters(), consuming memory without contributing to computation. The forward() method already had the correct guard (if self.with_cross_attention:), so the instantiation and the usage were inconsistent.

This fix wraps both instantiations in if with_cross_attention:, so the modules are only created when actually needed.
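
For reference, the resulting pattern looks roughly like this (a minimal sketch; the reduced constructor signature and the nn.MultiheadAttention stand-in are illustrative, not MONAI's actual CrossAttentionBlock API):

    import torch.nn as nn

    class TransformerBlock(nn.Module):
        def __init__(self, hidden_size: int, with_cross_attention: bool = False) -> None:
            super().__init__()
            self.with_cross_attention = with_cross_attention
            if with_cross_attention:
                # Created only when needed, so no dead parameters end up
                # registered in model.parameters() when the flag is False.
                self.norm_cross_attn = nn.LayerNorm(hidden_size)
                self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)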

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.

@coderabbitai
Contributor

coderabbitai Bot commented May 12, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4a52cb4e-d7b8-4e51-bc6b-ff6d7148295a

📥 Commits

Reviewing files that changed from the base of the PR and between 5ce030a and 1645458.

📒 Files selected for processing (2)
  • monai/networks/blocks/transformerblock.py
  • tests/networks/blocks/test_transformerblock.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • monai/networks/blocks/transformerblock.py
  • tests/networks/blocks/test_transformerblock.py

📝 Walkthrough

Walkthrough

The TransformerBlock constructor now creates norm_cross_attn and cross_attn only when with_cross_attention=True; these modules are not instantiated or registered when with_cross_attention=False. The docstring for with_cross_attention was updated to state this. Tests were added to assert cross-attention modules/parameters are absent when disabled, present when enabled, and that a forward pass with context yields an output with the same shape as the input when cross-attention is enabled.
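
In code, the described assertions take roughly this shape (a sketch assuming MONAI's documented TransformerBlock signature; the concrete sizes are illustrative, not taken from the PR diff):

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    # Disabled: the cross-attention modules should not exist at all.
    block = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4)
    assert not hasattr(block, "norm_cross_attn")
    assert not hasattr(block, "cross_attn")

    # Enabled: modules exist, and a forward pass with a context tensor
    # yields an output with the same shape as the input.
    block = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, with_cross_attention=True)
    x = torch.randn(2, 16, 64)
    context = torch.randn(2, 8, 64)
    assert block(x, context=context).shape == x.shape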

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 14.29% which is insufficient. The required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)
  • Title check (✅ Passed): Title clearly summarizes the main change: conditionally instantiating cross-attention modules only when with_cross_attention=True.
  • Description check (✅ Passed): Description covers the problem, solution, and confirms non-breaking change with new tests. Properly formatted with required sections.
  • Linked Issues check (✅ Passed): Changes fully address issue #8845: cross-attention modules now instantiate only when with_cross_attention=True, preventing unused parameter registration.
  • Out of Scope Changes check (✅ Passed): All changes are scoped to the issue: conditional instantiation in TransformerBlock and corresponding unit tests. No extraneous modifications.


@chhayankjain
Contributor Author

/black

@chhayankjain marked this pull request as ready for review on May 12, 2026 at 03:08
Contributor

@coderabbitai Bot left a comment


🧹 Nitpick comments (1)
monai/networks/blocks/transformerblock.py (1)

42-52: ⚡ Quick win

Document missing parameters in docstring.

The docstring Args section is missing documentation for causal, sequence_length, and with_cross_attention. Especially given this PR's focus on with_cross_attention, documenting its purpose and behavior would improve API clarity. As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

📝 Suggested docstring additions
             qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
+            sequence_length (int | None, optional): length of the input sequence for causal masking. Defaults to None.
+            with_cross_attention (bool, optional): whether to include cross-attention layers. When False,
+                norm_cross_attn and cross_attn modules are not instantiated. Defaults to False.
             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/blocks/transformerblock.py` around lines 42 - 52, The
docstring for the TransformerBlock constructor is missing descriptions for the
parameters causal, sequence_length, and with_cross_attention; update the Args
section of the transformerblock.py docstring (the TransformerBlock class /
__init__ signature) to add short Google-style entries explaining: causal (bool):
whether to apply causal masking for autoregressive attention; sequence_length
(int or None): expected input sequence length used for positional/attention
shaping or None if dynamic; with_cross_attention (bool): whether to enable an
extra cross-attention module that attends to an external memory/key-value input
(describe expected shapes/behavior such as when to supply cross-attention
inputs). Ensure wording is concise and consistent with the existing arg style
and defaults.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 47a20173-6a43-476f-9df8-fc4ca26ab245

📥 Commits

Reviewing files that changed from the base of the PR and between 586dea1 and f0103a9.

📒 Files selected for processing (2)
  • monai/networks/blocks/transformerblock.py
  • tests/networks/blocks/test_transformerblock.py

@chhayankjain force-pushed the 8845-fix-transformer-cross-attention-init branch from f0103a9 to 5ce030a on May 12, 2026 at 03:16
Contributor

@coderabbitai Bot left a comment


🧹 Nitpick comments (1)
monai/networks/blocks/transformerblock.py (1)

41-59: ⚡ Quick win

Add a Raises section to __init__ docstring.

__init__ raises ValueError (Line 64, Line 67), but the docstring does not document raised exceptions in Google style.

As per coding guidelines "**/*.py: ... Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
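
For illustration, the requested section could look like this (conditions paraphrased, not quoted; the exact wording should be copied from the validation branches in __init__):

    Raises:
        ValueError: when ``dropout_rate`` is not between 0 and 1.
        ValueError: when ``hidden_size`` is not divisible by ``num_heads``.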

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/blocks/transformerblock.py` around lines 41 - 59, The __init__
docstring for TransformerBlock is missing a Google-style "Raises" section
documenting the ValueError exceptions; update the class __init__ docstring to
include a "Raises" section that lists ValueError and describes the exact
conditions under which it's raised (e.g., invalid sequence_length when
causal=True and incompatible use_flash_attention/use_combined_linear or other
validation checks thrown in __init__); reference the initializer method __init__
and the validation branches that currently raise ValueError so the messages
match the actual checks.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7d188182-b0c0-4d37-ac8c-36c4808cce29

📥 Commits

Reviewing files that changed from the base of the PR and between f0103a9 and 5ce030a.

📒 Files selected for processing (2)
  • monai/networks/blocks/transformerblock.py
  • tests/networks/blocks/test_transformerblock.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/networks/blocks/test_transformerblock.py

Fixes Project-MONAI#8845

  TransformerBlock previously instantiated norm_cross_attn and cross_attn
  unconditionally, even when with_cross_attention=False. These unused modules
  registered dead parameters in model.parameters(), wasting memory.

  Wrapped both instantiations in `if with_cross_attention:` to match the
  existing guard in forward(). Added tests to verify the modules and their
  parameters are absent when disabled, present when enabled, and that the
  forward pass with a context tensor works correctly.

Signed-off-by: chhayankjain <chhayank44@gmail.com>
@chhayankjain force-pushed the 8845-fix-transformer-cross-attention-init branch from 5ce030a to 1645458 on May 12, 2026 at 03:22


Development

Successfully merging this pull request may close these issues.

TransformerBlock instantiates CrossAttentionBlock even when with_cross_attention=False
