From 7a40ca283f95a3390803223b8cf9be63cb4a6938 Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Sat, 9 Mar 2024 16:07:48 -0800 Subject: [PATCH] Broadcast lower-rank tensors during batched matmul When performing a matmul on two tensors with mismatched ranks, at least one of which is greater than 3, broadcast the lower-rank tensor. This also fixes a bug in the batched cov transform. Signed-off-by: Thomas Benson --- include/matx/transforms/matmul.h | 14 +++++--- test/00_transform/Cov.cu | 4 ++- test/00_transform/MatMul.cu | 60 ++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/include/matx/transforms/matmul.h b/include/matx/transforms/matmul.h index 344a8ae5a..001b47ad6 100644 --- a/include/matx/transforms/matmul.h +++ b/include/matx/transforms/matmul.h @@ -808,7 +808,9 @@ class matxMatMulHandle_t { // Prep for batch looping using shape_type = typename TensorTypeA::desc_type::shape_type; - [[maybe_unused]] std::array idx{0}; + [[maybe_unused]] std::array a_idx{0}; + [[maybe_unused]] std::array b_idx{0}; + [[maybe_unused]] std::array c_idx{0}; [[maybe_unused]] auto a_shape = a.Shape(); [[maybe_unused]] size_t total_iter = 1; @@ -855,9 +857,9 @@ class matxMatMulHandle_t { for (size_t iter = 0; iter < total_iter; iter++) { // Get pointers into A/B/C for this round - auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx); - auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx); - auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx); + auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, a_idx); + auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, b_idx); + auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, c_idx); auto res = cublasLtMatmul( ltHandle, operationDesc, &salpha, (void *)ap, Adesc, (void *)bp, Bdesc, &sbeta, @@ -868,7 +870,9 @@ class matxMatMulHandle_t { MATX_ASSERT(res == CUBLAS_STATUS_SUCCESS, matxMatMulError); // Update all but the last 3 indices - UpdateIndices(a_adj, idx, 3); + UpdateIndices(a_adj, a_idx, 3); + UpdateIndices(b_adj, b_idx, 3); + UpdateIndices(c_adj, c_idx, 3); } } } diff --git a/test/00_transform/Cov.cu b/test/00_transform/Cov.cu index 22bc800db..bc92ba437 100644 --- a/test/00_transform/Cov.cu +++ b/test/00_transform/Cov.cu @@ -101,10 +101,12 @@ TYPED_TEST(CovarianceTestFloatTypes, BatchedCov) (batched_out = cov(batched_in)).run(); + cudaDeviceSynchronize(); + for (int im = 0; im < m; im++) { for (int in = 0; in < n; in++) { for (int ik = 0; ik < k; ik++) { - auto bv = slice<2>(batched_out, {im,in,ik,0,0}, {matxDropDim,matxDropDim,matxDropDim,matxKeepDim,matxKeepDim}); + auto bv = slice<2>(batched_out, {im,in,ik,0,0}, {matxDropDim,matxDropDim,matxDropDim,matxEnd,matxEnd}); MATX_TEST_ASSERT_COMPARE(this->pb, bv, "c_cov", this->thresh); } } diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu index feb1f72c2..dd14f6fcb 100644 --- a/test/00_transform/MatMul.cu +++ b/test/00_transform/MatMul.cu @@ -659,6 +659,66 @@ TYPED_TEST(MatMulTestFloatNonHalfTypes, MatMulOp) MATX_EXIT_HANDLER(); } +TYPED_TEST(MatMulTestFloatNonHalfTypes, MatMulBroadcast) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t n = 16; + constexpr index_t b = 8; + constexpr index_t x = 3; + constexpr index_t y = 4; + + tensor_t eye2{{n, n}}; + tensor_t a5{{x, y, b, n, n}}; + tensor_t c5{{x, y, b, n, n}}; + + const TypeParam two { 2.0 }; + const TypeParam three { 3.0 }; + + (eye2 = two*eye({n,n})).run(); + (a5 = three).run(); + + (c5 = 0).run(); + // Broadcast eye2, scaling each entry in a5 by 2 + (c5 = matmul(eye2, a5)).run(); + + cudaDeviceSynchronize(); + + for (index_t i0 = 0; i0 < x; i0++) + for (index_t i1 = 0; i1 < y; i1++) + for (index_t i2 = 0; i2 < b; i2++) + for (index_t i3 = 0; i3 < n; i3++) + for (index_t i4 = 0; i4 < n; i4++) { + if constexpr (is_complex_v) { + ASSERT_NEAR(c5(i0,i1,i2,i3,i4).real(), 2.0*a5(i0,i1,i2,i3,i4).real(), this->thresh); + ASSERT_NEAR(c5(i0,i1,i2,i3,i4).imag(), 2.0*a5(i0,i1,i2,i3,i4).imag(), this->thresh); + } else { + ASSERT_NEAR(c5(i0,i1,i2,i3,i4), two*a5(i0,i1,i2,i3,i4), this->thresh); + } + } + + (c5 = 0).run(); + // Broadcast eye2, scaling each entry in a5 by 2 + (c5 = matmul(a5, eye2)).run(); + + cudaDeviceSynchronize(); + + for (index_t i0 = 0; i0 < x; i0++) + for (index_t i1 = 0; i1 < y; i1++) + for (index_t i2 = 0; i2 < b; i2++) + for (index_t i3 = 0; i3 < n; i3++) + for (index_t i4 = 0; i4 < n; i4++) { + if constexpr (is_complex_v) { + ASSERT_NEAR(c5(i0,i1,i2,i3,i4).real(), 2.0*a5(i0,i1,i2,i3,i4).real(), this->thresh); + ASSERT_NEAR(c5(i0,i1,i2,i3,i4).imag(), 2.0*a5(i0,i1,i2,i3,i4).imag(), this->thresh); + } else { + ASSERT_NEAR(c5(i0,i1,i2,i3,i4), two*a5(i0,i1,i2,i3,i4), this->thresh); + } + } + + MATX_EXIT_HANDLER(); +} + TYPED_TEST(MatMulTestFloatTypes, MediumMatVec) { MATX_ENTER_HANDLER();