Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions csrc/runtime/fusion_cache_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,54 @@ void resetAllocationDomainAndContiguity(
}
const auto [sizes, strides] =
inferAllocationSizesAndStrides(tensor, tv, ExpressionEvaluator());
auto contiguity_without_reduction = computeContiguity(sizes, strides);
std::vector<std::optional<bool>> contiguity;
int64_t index = 0;
for (auto id : tv->getMaybeAllocationDomain()) {

const auto& alloc = tv->getMaybeAllocationDomain();

// Custom contiguity inference that considers IterDomain information
std::vector<std::optional<bool>> contiguity(alloc.size(), std::nullopt);

// Single pass from right to left with two dynamic indices:
// - alloc_idx: iterates through allocation domain
// - sizes_idx: tracks position in sizes/strides (excludes reductions)
int64_t sizes_idx = (int64_t)sizes.size() - 1;
int64_t prev_non_skipped_sizes_idx = -1;

for (int64_t alloc_idx = (int64_t)alloc.size() - 1; alloc_idx >= 0; --alloc_idx) {
auto id = alloc[alloc_idx];

// Reduction dimensions: nullopt contiguity (already set), no entry in sizes/strides
if (id->isReduction()) {
contiguity.push_back(std::nullopt);
} else if (
!id->isBroadcast() &&
!contiguity_without_reduction[index].has_value()) {
contiguity.push_back(true);
index++;
// Don't decrement sizes_idx since reductions have no entry
continue;
}

// This dimension has an entry in sizes/strides
NVF_CHECK(sizes_idx >= 0, "Sizes index out of bounds");

// Broadcast dimensions: nullopt contiguity (already set), but has entry in sizes/strides
if (id->isBroadcast()) {
sizes_idx--; // Move to next dimension in sizes/strides
continue;
}

// Non-broadcast, non-reduction dimension
if (prev_non_skipped_sizes_idx == -1) {
// This is the rightmost (innermost) non-skipped dimension
// It's contiguous if stride == 1
contiguity[alloc_idx] = (strides[sizes_idx] == 1);
} else {
contiguity.push_back(contiguity_without_reduction[index++]);
// A dimension is contiguous if its stride equals the stride of the
// next dimension multiplied by that dimension's size
contiguity[alloc_idx] = (strides[sizes_idx] ==
strides[prev_non_skipped_sizes_idx] * sizes[prev_non_skipped_sizes_idx]);
}

prev_non_skipped_sizes_idx = sizes_idx;
sizes_idx--; // Move to next dimension in sizes/strides
}

NVF_CHECK(sizes_idx == -1, "Not all sizes/strides were consumed");

tv->setContiguity(contiguity);
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/runtime/fusion_kernel_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ std::unordered_map<Val*, PolymorphicValue> FusionKernelRuntime::
}

args_manager.updateWithSegmentOutputs(
group_to_run->outputs(), group_runtime_outputs, run_order_id, true);
group_to_run->outputs(), group_runtime_outputs, run_order_id, false);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Check whether changing this flag from true to false is intentional and related to the contiguity computation fix, as it's not explained in the PR description.

}

if (isProfilerEnabled()) {
Expand Down
24 changes: 11 additions & 13 deletions csrc/tensor_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,22 +318,20 @@ inferAllocationSizesAndStrides(
allocation_sizes.reserve(alloc.size());
allocation_strides.reserve(alloc.size());
for (IterDomain* id : alloc | TensorDomain::kNoReductions) {
auto it = active_ids.find(id);
NVF_ERROR(
it != active_ids.end(),
"Allocation domain of tensor ",
tv->toString(),
" is not complete. Missing ID: ",
id->toString());
auto [size, stride] = it->second;
if (id->isDeviceDim()) {
allocation_sizes.push_back(1);
allocation_strides.push_back(1);
continue;
}

auto it = active_ids.find(id);
if (it != active_ids.end()) {
allocation_sizes.push_back(it->second.first);
allocation_strides.push_back(it->second.second);
continue;
} else {
allocation_sizes.push_back(size);
}
// grouped matmul could introduce some IDs not in active_ids
// For those IDs, just push some dummy values
allocation_sizes.push_back(1);
allocation_strides.push_back(1);
allocation_strides.push_back(stride);
}
return {std::move(allocation_sizes), std::move(allocation_strides)};
}
Expand Down
Loading