fix: only instantiate CrossAttentionBlock when with_cross_attention=True #8848
chhayankjain wants to merge 1 commit into
Conversation
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 Walkthrough
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
/black
🧹 Nitpick comments (1)
monai/networks/blocks/transformerblock.py (1)
42-52: ⚡ Quick win: Document missing parameters in docstring.
The docstring Args section is missing documentation for `causal`, `sequence_length`, and `with_cross_attention`. Especially given this PR's focus on `with_cross_attention`, documenting its purpose and behavior would improve API clarity. As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

📝 Suggested docstring additions

```diff
     qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
     save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+    causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
+    sequence_length (int | None, optional): length of the input sequence for causal masking. Defaults to None.
+    with_cross_attention (bool, optional): whether to include cross-attention layers. When False,
+        norm_cross_attn and cross_attn modules are not instantiated. Defaults to False.
     use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/blocks/transformerblock.py` around lines 42 - 52, The docstring for the TransformerBlock constructor is missing descriptions for the parameters causal, sequence_length, and with_cross_attention; update the Args section of the transformerblock.py docstring (the TransformerBlock class / __init__ signature) to add short Google-style entries explaining: causal (bool): whether to apply causal masking for autoregressive attention; sequence_length (int or None): expected input sequence length used for positional/attention shaping or None if dynamic; with_cross_attention (bool): whether to enable an extra cross-attention module that attends to an external memory/key-value input (describe expected shapes/behavior such as when to supply cross-attention inputs). Ensure wording is concise and consistent with the existing arg style and defaults.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 47a20173-6a43-476f-9df8-fc4ca26ab245
📒 Files selected for processing (2)
monai/networks/blocks/transformerblock.py
tests/networks/blocks/test_transformerblock.py
f0103a9 to 5ce030a (Compare)
🧹 Nitpick comments (1)
monai/networks/blocks/transformerblock.py (1)
41-59: ⚡ Quick win: Add a `Raises` section to the `__init__` docstring.
`__init__` raises `ValueError` (Line 64, Line 67), but the docstring does not document raised exceptions in Google style. As per coding guidelines: "**/*.py: ... Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/blocks/transformerblock.py` around lines 41 - 59, The __init__ docstring for TransformerBlock is missing a Google-style "Raises" section documenting the ValueError exceptions; update the class __init__ docstring to include a "Raises" section that lists ValueError and describes the exact conditions under which it's raised (e.g., invalid sequence_length when causal=True and incompatible use_flash_attention/use_combined_linear or other validation checks thrown in __init__); reference the initializer method __init__ and the validation branches that currently raise ValueError so the messages match the actual checks.
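For illustration only, a Google-style `Raises` section could look like the sketch below. The signature is trimmed, and the two conditions shown are assumptions about what the `ValueError` branches in `__init__` actually check; the real entries should be copied from those branches so the docstring matches the code.

```python
import torch.nn as nn


class TransformerBlock(nn.Module):
    # Signature trimmed to the parameters needed for the illustration.
    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            dropout_rate: dropout ratio. Defaults to 0.0.

        Raises:
            ValueError: when ``dropout_rate`` is not between 0 and 1.  (assumed condition)
            ValueError: when ``hidden_size`` is not divisible by ``num_heads``.  (assumed condition)
        """
        super().__init__()
        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")
```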
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7d188182-b0c0-4d37-ac8c-36c4808cce29
📒 Files selected for processing (2)
monai/networks/blocks/transformerblock.py
tests/networks/blocks/test_transformerblock.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/networks/blocks/test_transformerblock.py
Fixes Project-MONAI#8845

TransformerBlock previously instantiated norm_cross_attn and cross_attn unconditionally, even when with_cross_attention=False. These unused modules registered dead parameters in model.parameters(), wasting memory. Wrapped both instantiations in `if with_cross_attention:` to match the existing guard in forward(). Added tests to verify the modules and their parameters are absent when disabled, present when enabled, and that the forward pass with a context tensor works correctly.

Signed-off-by: chhayankjain <chhayank44@gmail.com>
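A minimal sketch of the guarded instantiation and the kind of check the new tests perform. It assumes constructor arguments such as `hidden_size`/`num_heads` and uses a stand-in for MONAI's cross-attention block; everything below is illustrative rather than the actual MONAI code.

```python
from typing import Optional

import torch
import torch.nn as nn


class CrossAttentionSketch(nn.Module):
    """Stand-in for a cross-attention block that attends to a context tensor."""

    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, context, context)
        return out


class TransformerBlockSketch(nn.Module):
    """Illustrates the fix: cross-attention modules exist only when requested."""

    def __init__(self, hidden_size: int, num_heads: int, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        # The change described above: build norm_cross_attn and cross_attn only
        # when they will be used, mirroring the guard already present in forward().
        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = CrossAttentionSketch(hidden_size, num_heads)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        if self.with_cross_attention:
            x = x + self.cross_attn(self.norm_cross_attn(x), context)
        return x


# Checks in the spirit of the added tests: no dead parameters when disabled,
# modules present when enabled, and the forward pass accepts a context tensor.
without = TransformerBlockSketch(hidden_size=64, num_heads=4, with_cross_attention=False)
with_xa = TransformerBlockSketch(hidden_size=64, num_heads=4, with_cross_attention=True)
assert not hasattr(without, "cross_attn") and not hasattr(without, "norm_cross_attn")
assert sum(p.numel() for p in without.parameters()) < sum(p.numel() for p in with_xa.parameters())
x, ctx = torch.randn(2, 16, 64), torch.randn(2, 8, 64)
assert with_xa(x, ctx).shape == x.shape
```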
5ce030a to 1645458 (Compare)
Fixes #8845.
Description
`TransformerBlock` previously instantiated `norm_cross_attn` and `cross_attn` unconditionally in `__init__`, even when `with_cross_attention=False`. These unused modules registered dead parameters in `model.parameters()`, consuming memory without contributing to computation. The `forward()` method already had the correct guard (`if self.with_cross_attention:`), so the instantiation and the usage were inconsistent.

This fix wraps both instantiations in `if with_cross_attention:`, so the modules are only created when actually needed.

Types of changes