Misaligned memory access with 3D sharding #5920
Closed
Labels: Multi-GPU, allocation domain (issues related to allocation domain support)
Description
Repro: #5919
With NVFUSER_DUMP=fusion_ir, the output tensor is printed as:
T1_g___bfloat[iblockIdx.x42{( ceilDiv(( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) ), 128) )}, iUS41{1}, iV39{2}, ithreadIdx.x43{128}, ideviceIdx.y32{2}, ideviceIdx.x33{2}] ca_pos( 2 ) produce_pos( 2 ) (DeviceMesh{{0 1}{2 3}})
logical domain: (iS28{i0}, iS29{128}, iS30{i2}, iS31{i3}, ideviceIdx.y32{2})
allocation domain: (iS30{i2}, iS28{i0}, iS29{128}, ideviceIdx.x33{2}, iS34{( ceilDiv(i3, 2) )}, ideviceIdx.y32{2})
contiguity: t t t t t t
Outer split: iS31{i3} by factor 2 -> ideviceIdx.x33{2}, iS34{( ceilDiv(i3, 2) )}
Merge: iS29{128} and iS34{( ceilDiv(i3, 2) )} -> iS35{( ( ceilDiv(i3, 2) ) * 128 )}
Merge: iS28{i0} and iS35{( ( ceilDiv(i3, 2) ) * 128 )} -> iS36{( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) )}
Merge: iS30{i2} and iS36{( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) )} -> iS37{( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) )}
Split: iS37{( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) )} by factor 2 -> iS38{( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) )}, iV39{2}
Split: iS38{( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) )} by factor 1 -> iS40{( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) )}, iUS41{1}
Split: iS40{( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) )} by factor 128 -> iblockIdx.x42{( ceilDiv(( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) ), 128) )}, ithreadIdx.x43{128}
loop domain: (iblockIdx.x42{( ceilDiv(( ceilDiv(( i2 * ( i0 * ( ( ceilDiv(i3, 2) ) * 128 ) ) ), 2) ), 128) )}, iUS41{1}, iV39{2}, ithreadIdx.x43{128}, ideviceIdx.y32{2}, ideviceIdx.x33{2})
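To make the transform chain above easier to follow, here is a small sketch that replays it with concrete extents. Only i3 = 10 comes from this issue; the values used for i0 and i2 are hypothetical, and the function name `loop_extents` is made up for illustration.

```python
def ceil_div(a, b):
    """Integer ceiling division, matching ceilDiv in the IR dump."""
    return -(-a // b)


def loop_extents(i0, i2, i3):
    """Replay the split/merge chain printed for T1 and return each extent."""
    s34 = ceil_div(i3, 2)      # outer split of iS31 by 2 -> ideviceIdx.x33{2}, iS34
    s35 = 128 * s34            # merge iS29{128} and iS34
    s36 = i0 * s35             # merge iS28 and iS35
    s37 = i2 * s36             # merge iS30 and iS36
    s38 = ceil_div(s37, 2)     # split by 2 -> iS38, iV39{2}
    s40 = ceil_div(s38, 1)     # split by 1 -> iS40, iUS41{1}
    block = ceil_div(s40, 128) # split by 128 -> iblockIdx.x42, ithreadIdx.x43{128}
    return {"iS34": s34, "iS37": s37, "iS38": s38, "iblockIdx.x42": block}


# i3 = 10 is taken from the issue; i0 = 2 and i2 = 3 are assumed values.
print(loop_extents(i0=2, i2=3, i3=10))
```

The merged extent iS37 always carries a factor of 128, so the split by 2 into iV39 divides it evenly no matter how large iS34 is.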
Is the vector size of 2 (iV39{2}) a problem? The innermost dimension's per-device size is 10 / cp_size = 5, which is odd.
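One way to reason about the question is to check where each 2-element bfloat16 access would start. The sketch below is only an illustration under stated assumptions: the base pointer is aligned, elements are 2 bytes, and the two scenarios (a packed local shard versus a shard that starts 5 elements into a larger buffer) are hypothetical. The issue does not establish which of them, if either, matches the failing kernel.

```python
ELEM_BYTES = 2  # bfloat16


def misaligned_vector_starts(start_elem, num_elems, vec, elem_bytes=ELEM_BYTES):
    """Return element offsets where a vec-wide access starts off a
    (vec * elem_bytes)-byte boundary, assuming an aligned base pointer."""
    width = vec * elem_bytes
    return [
        e
        for e in range(start_elem, start_elem + num_elems, vec)
        if (e * elem_bytes) % width != 0
    ]


# Scenario A (assumed): the local shard is packed, 128 rows of 5 elements,
# and starts at element 0. The flat extent 640 is even.
print(misaligned_vector_starts(0, 128 * 5, vec=2))

# Scenario B (assumed): the shard starts at element offset 5, e.g. the second
# half of a 10-element row in an unsharded buffer.
print(misaligned_vector_starts(5, 4, vec=2))
```

In scenario A no access is misaligned, while in scenario B every access is, because an odd element offset puts each 4-byte access 2 bytes off its boundary.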
Incidentally, I was reading Fuser/csrc/scheduler/vectorize_helper.cpp, line 833 at commit 1bded43:
`getProjectedExtent(logical_id), num_devices);`
`getProjectedExtent(logical_id)` might not even be divisible by the number of devices, although it should be divisible by the size of the corresponding mesh axis.

cc @Priya2698
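The divisibility concern can be checked with the numbers from this issue: a logical extent of 10 on a 2x2 mesh. This is plain arithmetic, not nvFuser code, and the variable names are made up for illustration.

```python
# DeviceMesh{{0 1}{2 3}} from the IR dump above.
mesh = [[0, 1], [2, 3]]
extent = 10  # size of the sharded logical dimension, per the issue

num_devices = sum(len(row) for row in mesh)  # total devices in the mesh
axis_size = len(mesh[0])                     # size of one mesh axis

print(num_devices, extent % num_devices)  # remainder when dividing by all devices
print(axis_size, extent % axis_size)      # remainder when dividing by one axis
```

Here the extent divides evenly by the mesh axis size (2) but not by the total device count (4), which is the case the comment above is worried about.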