Skip to content

[TIR] Fix reduce buffer allocation position#17799

Merged
Hzfengsy merged 2 commits into
apache:mainfrom
wrongtest-intellif:fix_reduction_buffer_position
Apr 3, 2025
Merged

[TIR] Fix reduce buffer allocation position#17799
Hzfengsy merged 2 commits into
apache:mainfrom
wrongtest-intellif:fix_reduction_buffer_position

Conversation

@wrongtest-intellif

@wrongtest-intellif wrongtest-intellif commented Apr 2, 2025

Copy link
Copy Markdown
Contributor

The change modify lca detector to ensure the reduction buffer allocation position dominates all reduce related loops. This behavior is also compatible with same logic in DecomposeReduction.

Below is a working case with wrong cpu compile results. We just build a simple reduce-sum function, with outer tiling loops in order "SRS".

import numpy as np
import tvm
from tvm import tir, te, topi

x = te.placeholder(name="x", shape=[256, 256, 256], dtype="float32")
y = topi.sum(x, axis=1)
f = te.create_prim_func([x, y])
origin_mod = tvm.IRModule.from_expr(f)

s = tir.schedule.Schedule(f)
blk = s.get_child_blocks(s.get_block("root"))[-1]
i, j, k = s.get_loops(blk)
i0, i1 = s.split(i, factors=[4, 64])
j0, j1 = s.split(j, factors=[4, 64])
k0, k1 = s.split(k, factors=[4, 64])
s.reorder(i0, k0, j0, i1, k1, j1)
write_blk = s.cache_write(blk, 0, "")
s.reverse_compute_at(write_blk, j0)
tiled_mod = s.mod
print(tiled_mod)

x = tvm.nd.array(np.random.uniform(0, 128, [256, 256, 256]).astype("float32"))
y = tvm.nd.array(np.zeros([256, 256], "float32"))
expect = np.sum(x.numpy(), axis=1, keepdims=False)

lib1 = tvm.compile(origin_mod, target="llvm")
lib1(x, y)
np.testing.assert_allclose(y.numpy(), expect)  # origin module is correct

lib2 = tvm.compile(tiled_mod, target="llvm")
lib2(x, y)
np.testing.assert_allclose(y.numpy(), expect)   # scheduled result is wrong

The error is due to transformation tir.PlanAndUpdateBufferAllocationLocation and tir.CompactBufferAllocation:

# after PlanAndUpdateBufferAllocationLocation
# the reduction write buffer `x_red_` position is incorrectly put under reduction tiling loop `k1_0`

@I.ir_module
class Module:
    @T.prim_func
    def main(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
        for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4):
            with T.block(""):
                x_red_ = T.alloc_buffer((256, 256))
                for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
                    with T.block("x_red"):
                        v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
                        v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
                        v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
                        if v_k1 == 0:
                            x_red_[v_ax0, v_ax1] = T.float32(0.0)
                        x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
                for ax0, ax1 in T.grid(64, 64):
                    with T.block("x_red_"):
                        v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
                        v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
                        x_red[v0, v1] = x_red_[v0, v1]

# then after CompactBufferAllocation, because of incorrect planned position
# the compacted buffer lead to incorrect reuse across spatial loop `ax1_0`
# the correct compacted shape should be ` x_red_ = T.alloc_buffer((64, 4 * 64))`
@I.ir_module
class Module:
    @T.prim_func
    def main(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, "keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4):
            with T.block(""):
                x_red_ = T.alloc_buffer((64, 64))
                for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
                    with T.block("x_red"):
                        if k1_0 * 64 + k1_1 == 0:
                            x_red_[ax0_0 * 64 + ax0_1 - ax0_0 * 64, ax1_0 * 64 + ax1_1 - ax1_0 * 64] = T.float32(0.0)
                        x_red_[ax0_0 * 64 + ax0_1 - ax0_0 * 64, ax1_0 * 64 + ax1_1 - ax1_0 * 64] = x_red_[ax0_0 * 64 + ax0_1 - ax0_0 * 64, ax1_0 * 64 + ax1_1 - ax1_0 * 64] + x[ax0_0 * 64 + ax0_1, k1_0 * 64 + k1_1, ax1_0 * 64 + ax1_1]
                for ax0, ax1 in T.grid(64, 64):
                    with T.block("x_red_"):
                        x_red[ax0_0 * 64 + ax0, ax1_0 * 64 + ax1] = x_red_[ax0_0 * 64 + ax0 - ax0_0 * 64, ax1_0 * 64 + ax1 - ax1_0 * 64]

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR addresses a bug in the TIR transformation where the reduction write buffer allocation position is incorrectly placed under a reduction tiling loop. The changes in this PR add a new test case to verify that the reduction buffer allocation now correctly dominates all reduce loops.

  • Added a new test function (test_reduce_buffer_dominate_reduce_loops) to compare the IR before and after applying the buffer allocation transformation.
  • Enhanced the test by including explicit TIR annotations (T.reads/T.writes) to ensure the transformation preserves correct buffer reuse behavior.
Files not reviewed (1)
  • src/tir/analysis/buffer_access_lca_detector.cc: Language not supported
Comments suppressed due to low confidence (2)

tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py:432

  • [nitpick] The allocated shape for x_red_ is (256, 256), but based on the intended reduction buffer compaction (as noted in the description), a more compact shape (e.g., (64, 256) or (64, 64)) might be expected. Consider verifying that the allocation shape aligns with the intended transformation.
x_red_ = T.alloc_buffer((256, 256))

tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py:443

  • [nitpick] The index expression 'ax0_0 * 64 + ax0_1 - ax0_0 * 64' simplifies to 'ax0_1', and similarly for the second index, which might improve readability if simplified. Consider simplifying these expressions to make the intent clearer.
x_red_[ax0_0 * 64 + ax0_1 - ax0_0 * 64, ax1_0 * 64 + ax1_1 - ax1_0 * 64] = T.float32(0.0)

@Hzfengsy Hzfengsy left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Hzfengsy Hzfengsy merged commit 6365a30 into apache:main Apr 3, 2025
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025
* fix reduce buffer allocation position

* fix test_tir_analysis_detect_buffer_access_lca.py::test_buffer_load_store
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants