
Fix getContigMergeOfInnerSize for device-parallel logical dimensions#5929

Merged
wujingyue merged 3 commits into main from wjy/align
Feb 6, 2026

Conversation

@wujingyue
Collaborator

Fixes #5920

@wujingyue
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 6, 2026

Description

  • Fix getContigMergeOfInnerSize to properly handle device-parallel logical dimensions

  • Add conditional logic to distinguish device dimensions from regular dimensions

  • Add comprehensive test case reproducing issue #5920 (Misaligned memory access with 3D sharding) with vectorization

  • Refactor existing test to improve mesh definition placement
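The conditional described in the first two bullets can be illustrated with a minimal Python sketch. The function name and integer types here are hypothetical stand-ins: in the real code (csrc/scheduler/vectorize_helper.cpp) extents are symbolic IR Vals, `oneVal()` is a symbolic 1, and the division is built with `SimplifyingIrBuilder::divExpr`.

```python
# Hypothetical integer model of the fix in getContigMergeOfInnerSize.
# Plain ints are used here only to illustrate the arithmetic; the real
# implementation operates on symbolic IR values.

def sharded_extent(projected_extent: int, is_device_dim: bool, num_devices: int) -> int:
    """Per-rank contribution of one logical dimension to the contiguous inner size."""
    if is_device_dim:
        # A device-parallel dimension is sharded across ranks, so it
        # contributes a fixed extent of 1 locally (the behavior this PR adds).
        return 1
    # Non-device dimensions keep the original logic: divide the
    # projected extent evenly across devices.
    return projected_extent // num_devices

# A device-parallel dim contributes 1 regardless of its global extent;
# a non-device dim of extent 16 over 4 devices contributes 4.
assert sharded_extent(8, True, 8) == 1
assert sharded_extent(16, False, 4) == 4
```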

Changes walkthrough

Relevant files

Bug fix: csrc/scheduler/vectorize_helper.cpp (+7/-2)
Fix device dimension handling in vectorization helper

  • Add conditional check for logical_id->isDeviceDim() in
    getContigMergeOfInnerSize
  • Handle device dimensions by setting sharded_extent to the container's
    one value
  • Preserve original logic for non-device dimensions using division by
    num_devices

Tests: tests/python/multidevice/test_multidevice.py (+60/-1)
Add vectorization test and refactor existing test

  • Add new test_pointwise_vectorization function reproducing issue #5920
    (Misaligned memory access with 3D sharding)
  • Move the mesh definition inside the FusionDefinition block in the
    existing test
  • Add comprehensive test with device mesh, parallelization, and
    vectorization

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Correctness

The fix correctly handles device-parallel logical dimensions by using oneVal() instead of dividing by num_devices. This prevents incorrect sharding calculations for device dimensions. The conditional logic properly distinguishes between device dimensions and other sharded dimensions.

    Val* sharded_extent;
    if (logical_id->isDeviceDim()) {
      sharded_extent = of_tv->container()->oneVal();
    } else {
      sharded_extent = SimplifyingIrBuilder::divExpr(
          getProjectedExtent(logical_id), num_devices);
    }

Test failures (partial, pipeline still running)

• (Medium, 3) Shape mismatch in thunderfx higher-order inplace alias update test (thunder/tests/test_update_aliases.py)
  Platforms: A100, GB200, H100
  Test name: thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32

• (Medium, 1) thunder nanogpt scalar mismatch in test_networks (nvFuser CUDA)
  Platform: GB200
  Test name: thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

• Updates vectorization contiguity analysis (getContigMergeOfInnerSize) to treat device-parallel logical dimensions differently when computing the product of inner extents.
• Adds an MPI repro test for issue #5920 (Misaligned memory access with 3D sharding) that uses a 3D DeviceMesh and schedules mesh_x/mesh_y/mesh_z-parallel axes to exercise vectorization on sharded tensors.
• Minor refactor in the existing multidevice pointwise test to construct the DeviceMesh inside the fusion definition, consistent with how TVs are configured.
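For context on the repro, the per-rank shape arithmetic of sharding over a 3D device mesh can be sketched as follows. The shapes and the helper below are illustrative assumptions, not code from tests/python/multidevice/test_multidevice.py, which builds the mesh and parallelizes axes through the nvFuser Python frontend.

```python
# Hypothetical sketch of per-rank shapes under 3D sharding (issue #5920).

def local_shape(global_shape, mesh_shape):
    """Per-rank tensor shape when each leading axis is sharded over the
    corresponding device-mesh dimension (assumes even divisibility)."""
    assert len(mesh_shape) <= len(global_shape)
    out = list(global_shape)
    for i, d in enumerate(mesh_shape):
        assert out[i] % d == 0, "axis must divide evenly across the mesh"
        out[i] //= d
    return out

# On a 2x2x2 mesh (8 ranks), the three leading axes are sharded while the
# innermost axis stays whole, so vectorization sees the full inner extent:
assert local_shape([4, 4, 4, 128], [2, 2, 2]) == [2, 2, 2, 128]
```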

Confidence Score: 2/5

• This PR likely needs a correction in the contiguity-size computation before it is safe to merge.
• The core C++ change appears to drop the extent contribution of device-parallel logical dimensions (setting it to 1), which can miscompute the contiguous inner merge size used for vectorization decisions; the added test may not cover the incorrect case if the intended semantics require including the device-dim extent.
• csrc/scheduler/vectorize_helper.cpp

Important Files Changed

csrc/scheduler/vectorize_helper.cpp: Adjusts sharded-extent computation in the contiguous-inner-size analysis for device-parallel logical dims; the change appears to drop the device-dim extent from the product, which can miscompute the vectorizable inner span.
tests/python/multidevice/test_multidevice.py: Moves DeviceMesh creation inside the fusion definition and adds a new MPI repro test for vectorization with mesh_x/mesh_y/mesh_z parallelized axes; the test looks reasonable and skips unsupported world sizes.

Sequence Diagram

sequenceDiagram
    participant Test as pytest (MPI)
    participant FD as FusionDefinition
    participant TV as TensorView scheduling
    participant Sched as vectorize_helper.cpp
    participant Exec as fd.execute

    Test->>FD: define inp_tv/out_tv (device-parallel layout)
    FD->>TV: set_device_mesh(mesh)
    Test->>TV: parallelize axes (mesh_x/mesh_y/mesh_z)
    Test->>Exec: execute(sharded inp)
    Exec->>Sched: analyze vectorization
    Sched->>Sched: getContigMergeOfInnerSize(of_tv)
    Sched-->>Exec: computed contiguous inner size
    Exec-->>Test: produce sharded output
    Test->>Test: assert_close(output, expected)

@greptile-apps greptile-apps bot left a comment (Contributor)

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +832 to +834

Val* sharded_extent;
if (logical_id->isDeviceDim()) {
  sharded_extent = of_tv->container()->oneVal();

Incorrect sharded extent

In getContigMergeOfInnerSize, when logical_id->isDeviceDim() you set sharded_extent = 1. That makes the contiguity product ignore the local per-rank extent of a device-parallel logical dimension, which can cause the computed contiguous-inner-merge size to be too large and enable vectorization when the local contiguous span is actually smaller. In this function, device dims still contribute their projected extent (already per-rank), so they shouldn't be dropped from the product.

Also appears to affect the same logic at csrc/scheduler/vectorize_helper.cpp:835-838, where the num_devices division is skipped entirely for device dims.

Val* sharded_extent;
if (logical_id->isDeviceDim()) {
  sharded_extent = of_tv->container()->oneVal();

Collaborator

IIUC, the patch is for the case where we don't have any expr between logical and allocation. Looked straightforward to me.

Yet I'm surprised that this didn't get caught earlier.

@wujingyue wujingyue merged commit 459d852 into main Feb 6, 2026
56 of 61 checks passed
@wujingyue wujingyue deleted the wjy/align branch February 6, 2026 17:29


Development

Successfully merging this pull request may close these issues:

Misaligned memory access with 3D sharding (#5920)
