
Support multimodule pipelining in 1F1B schedule #3129

Open

yashaswikarnati wants to merge 17 commits into NVIDIA:main from yashaswikarnati:yash/1f1b_changes

Conversation

@yashaswikarnati
Contributor

@yashaswikarnati yashaswikarnati commented Jan 28, 2026

Summary

Adds support for multi-module pipeline parallelism (encoder + LLM) in the 1F1B schedule.

Changes:

  • Add MultiModuleProcessGroupCollection for managing process groups across modules
  • Support dict-based tensor format {module_name: tensor} in forward/backward
  • Handle 2D/3D tensor conversion for P2P and bridge communication
  • Add backward_step_multimodule to handle backward for multimodule cases
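
The dict-based activation format can be illustrated with a minimal sketch; the helper name `forward_step_multimodule` and the module keys `"encoder"`/`"llm"` are hypothetical stand-ins, not the PR's actual API:

```python
def forward_step_multimodule(modules, input_tensors):
    """Run one forward micro-step where activations are keyed by module name.

    `modules` and `input_tensors` are dicts {module_name: callable/tensor};
    plain callables stand in here for the encoder and LLM pipeline modules.
    """
    return {name: fn(input_tensors[name]) for name, fn in modules.items()}


# Toy usage: two callables play the role of the encoder and LLM stages.
modules = {"encoder": lambda x: x * 2, "llm": lambda x: x + 1}
inputs = {"encoder": 3, "llm": 10}
outputs = forward_step_multimodule(modules, inputs)
# outputs == {"encoder": 6, "llm": 11}
```

The point of keying by module name is that the schedule can route each module's activations and gradients through its own process group without changing the shape of the 1F1B loop.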

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

- Rename ProcessGroupCollectionWrapper to MultiModuleProcessGroupCollection
- Rename language_model field to language_model_module_name for clarity
- Add language_model_module_name param to backward_step_multimodule
- Use functools.partial to bind param, keeping signature consistent
- Add type hints to _ensure_3d_tensor and _restore_tensor_shape
- Move is_multimodule check earlier for validation and backward selection
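
The `functools.partial` change in the commit list above can be sketched as follows; the signature of `backward_step_multimodule` shown here is simplified and hypothetical, only the binding pattern is the point:

```python
from functools import partial


def backward_step_multimodule(input_tensor, output_tensor, output_tensor_grad,
                              language_model_module_name):
    """Hypothetical shape of the multimodule backward step: the extra
    module-name parameter identifies which dict entry is the LLM."""
    return f"backward through {language_model_module_name}"


# Bind the extra argument up front so the resulting callable matches the
# 3-argument backward_step signature the 1F1B schedule already expects.
backward_step = partial(backward_step_multimodule,
                        language_model_module_name="llm")
result = backward_step(None, None, None)
```

Binding the module name with `partial` keeps the schedule's call sites unchanged: they invoke `backward_step(...)` the same way whether or not the multimodule path is active.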
@yashaswikarnati yashaswikarnati requested review from a team as code owners January 28, 2026 22:53
@copy-pr-bot

copy-pr-bot bot commented Jan 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team January 28, 2026 22:54
@dimapihtar dimapihtar added the complexity: high and Expert Review labels Jan 29, 2026
@dimapihtar
Contributor

/ok to test 2d7c176

Returns:
3D tensor (with singleton last dim if input was 2D), list of 3D tensors, or None.
"""
if tensor is None:
Contributor

Can you assert-fail if a 3D tensor is passed in and its last dim size != 1?

Contributor Author

I think for a 3D tensor it's a no-op; any 3D tensor is fine (the last dim can be 1 and no assert is needed). I'll make that clear in the docstring, thanks.

Contributor

If the original tensor is 3D and its last dim == 1, will _restore_tensor_from_comm make it 2D?
I.e., dims in = [a, b, 1], dims out = [a, b]?

Contributor Author

Good call-out! I added an assert in the comm-preparation step to prevent this: ee189df
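
The shape round-trip under discussion can be sketched at the shape level; the helper names here are hypothetical simplifications of `_ensure_3d_tensor` / `_restore_tensor_shape`, and the assert mirrors the ambiguity the reviewers identified:

```python
def ensure_3d_shape(shape):
    """2D shapes gain a singleton last dim for P2P/bridge communication;
    3D shapes pass through unchanged. A genuine 3D tensor ending in 1
    would be indistinguishable from a promoted 2D one after the round
    trip, so that case is asserted away before communication."""
    if len(shape) == 2:
        return shape + (1,)  # promoted: caller remembers it was 2D
    if len(shape) == 3:
        assert shape[-1] != 1, "ambiguous: 3D tensor with singleton last dim"
        return shape
    raise ValueError(f"expected 2D or 3D, got {len(shape)}D")


def restore_shape(shape, was_2d):
    # Inverse applied after communication: drop the singleton dim only
    # when the tensor was originally 2D.
    return shape[:-1] if was_2d else shape


promoted = ensure_3d_shape((8, 512))      # (8, 512, 1)
restored = restore_shape(promoted, was_2d=True)  # (8, 512)
```

Without the assert, `restore_shape` would silently turn a real `[a, b, 1]` tensor into `[a, b]`, which is exactly the failure mode raised in the thread above.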

input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
Contributor

Why is the [0] removed here?

Contributor Author

To handle individual tensors, lists, and dicts consistently, those checks now live inside deallocate_output_tensor, so passing just the first element is no longer needed.
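
The type dispatch described in the reply above can be sketched as follows; `deallocate_output_tensor_any` and `FakeTensor` are illustrative names, and setting a flag stands in for shrinking the tensor's storage as Megatron's real deallocation does:

```python
def deallocate_output_tensor_any(output, deallocate):
    """Accept a single tensor-like object, a list of them, or a
    {module_name: tensor} dict, recursing into containers so call
    sites no longer need to index output_tensor[0]."""
    if output is None or not deallocate:
        return
    if isinstance(output, dict):
        for t in output.values():
            deallocate_output_tensor_any(t, deallocate)
    elif isinstance(output, list):
        for t in output:
            deallocate_output_tensor_any(t, deallocate)
    else:
        output.freed = True  # stand-in for freeing the tensor's storage


class FakeTensor:
    freed = False


outs = {"encoder": FakeTensor(), "llm": [FakeTensor(), FakeTensor()]}
deallocate_output_tensor_any(outs, deallocate=True)
```

After the call, every leaf tensor in the dict (including those nested in lists) is marked freed, which is why the `[0]` indexing at the call site became unnecessary.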

@dimapihtar dimapihtar requested a review from erhoo82 February 4, 2026 15:10
@shifangx
Contributor

shifangx commented Feb 14, 2026

Hi, @yaoyu-33, @yashaswikarnati, could you help increase the priority of this PR?
As far as I know, #3129 is the last functionality PR for M4, and it is fundamental for DistTrain.

There are also other leftover PRs for M4, but #3129 is the most important one.

@shifangx
Contributor

/ok to test 597862e

return block


def create_module_with_grid(tp, pp, dp, grid_offset, hidden_size):
Contributor

Maybe create_module_and_grid is more appropriate, because this function creates the grid; it does not use a grid to create the model.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
@yashaswikarnati
Contributor Author

/ok to test 5f941d1

@shifangx
Contributor

shifangx commented Feb 23, 2026

Hi, @yaoyu-33, @yashaswikarnati, could you help increase the priority of this PR? As far as I know, #3129 is the last functionality PR for M4, and it is fundamental for DistTrain.

There are also other leftover PRs for M4, but #3129 is the most important one.

Hi, @yashaswikarnati, what is the next step for this PR?
@NVIDIA/core-adlr, @NVIDIA/pipeline-parallelism, @NVIDIA/mcore-oncall, @erhoo82, can you help review this PR?

@shifangx
Contributor

/ok to test 908ea5f

@shifangx
Contributor

Hi, @yashaswikarnati, can you help address the CI test issues?
Some test cases are failing.


# Apply grad scaling if needed (for last stage only)
for module_name in output_tensor.keys():
if output_tensor_grad[module_name] is None and config.grad_scale_func is not None:
Contributor

Just out of curiosity, why is the scaling only applied when the gradient is None?

Contributor

Doesn't this break when using gradient accumulation?
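
The pattern in the snippet under review can be sketched as follows; the function name and simplified signature are hypothetical, and the reading it illustrates is an assumption from the surrounding context: `output_tensor_grad[name] is None` marks the last pipeline stage, where `output_tensor` is the loss and `grad_scale_func` (e.g. fp16 loss scaling) is applied before backward, since earlier stages receive already-scaled gradients from downstream:

```python
def scale_for_backward(output_tensor, output_tensor_grad, grad_scale_func):
    """Apply grad scaling per module, but only where no incoming gradient
    exists (i.e. the module's output is a loss at the last stage)."""
    scaled = {}
    for name, out in output_tensor.items():
        if output_tensor_grad.get(name) is None and grad_scale_func is not None:
            scaled[name] = grad_scale_func(out)  # scale the loss itself
        else:
            scaled[name] = out  # gradient already scaled upstream
    return scaled


# Toy usage: the "llm" module sits on the last stage (no incoming grad),
# so its loss value is multiplied by a scale factor of 128.
scaled = scale_for_backward(
    {"llm": 2.0}, {"llm": None}, grad_scale_func=lambda x: x * 128
)
```

Under that reading, scaling only on `None` gradients is deliberate rather than a gradient-accumulation bug, but the reviewers' questions above remain the authoritative thread for confirming it.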

@yashaswikarnati
Contributor Author

/ok to test ee189df

@yashaswikarnati
Contributor Author

Hi, @yashaswikarnati, can you help to address CI test issue? Some test cases failed.

The test failures look unrelated and all tests pass locally; retriggered the CI.
