Skip to content

Commit a46c55d

Browse files
authored
Fix matvec output dims to match A rather than B (#523)
For matvecs, the batch dimensions for A and B should match and the final output dimension should match dim Rank-1 from A. Also generalize batching support so that the size of out_dims_ is based on the output rank.
1 parent 0fba213 commit a46c55d

2 files changed

Lines changed: 30 additions & 3 deletions

File tree

include/matx/operators/matvec.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ namespace matx
4848
OpB b_;
4949
float alpha_;
5050
float beta_;
51-
std::array<index_t, 2> out_dims_;
52-
mutable matx::tensor_t<typename OpA::scalar_type, 2> tmp_out_;
51+
static constexpr int RANK = remove_cvref_t<OpB>::Rank();
52+
std::array<index_t, RANK> out_dims_;
53+
mutable matx::tensor_t<typename OpA::scalar_type, RANK> tmp_out_;
5354

5455
public:
5556
using matxop = bool;
@@ -65,7 +66,7 @@ namespace matx
6566
a_(A), b_(B), alpha_(alpha), beta_(beta) {
6667

6768
for (int r = 0; r < Rank(); r++) {
68-
out_dims_[r] = b_.Size(r);
69+
out_dims_[r] = a_.Size(r);
6970
}
7071
}
7172

test/00_transform/MatMul.cu

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,12 @@ TYPED_TEST(MatMulTestFloatTypes, MediumMatVec)
682682
(cs = matvec(a, bs)).run();
683683
// example-end matvec-test-1
684684

685+
// Test the rank/size of the matvec operator
686+
auto a_times_bs = matvec(a, bs);
687+
ASSERT_EQ(a_times_bs.Rank(), 1);
688+
ASSERT_EQ(a_times_bs.Size(0), m);
689+
ASSERT_EQ(cs.Size(0), m);
690+
685691
MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh);
686692

687693
// Test also with rank-1 tensors rather than just slices
@@ -693,6 +699,26 @@ TYPED_TEST(MatMulTestFloatTypes, MediumMatVec)
693699

694700
MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh);
695701

702+
// Test with batching
703+
constexpr index_t batch1 = 5;
704+
constexpr index_t batch2 = 9;
705+
auto a_batch = clone<4>(a, {batch1, batch2, matxKeepDim, matxKeepDim});
706+
auto b_batch = clone<3>(bs, {batch1, batch2, matxKeepDim});
707+
auto batched_matvec = matvec(a_batch, b_batch);
708+
ASSERT_EQ(batched_matvec.Rank(), 3);
709+
ASSERT_EQ(batched_matvec.Size(0), batch1);
710+
ASSERT_EQ(batched_matvec.Size(1), batch2);
711+
ASSERT_EQ(batched_matvec.Size(2), m);
712+
auto result = make_tensor<TypeParam>(batched_matvec.Shape());
713+
(result = batched_matvec).run();
714+
for (index_t i = 0; i < batch1; i++) {
715+
for (index_t j = 0; j < batch2; j++) {
716+
auto rs = slice<1>(result, {i,j,0}, {matxDropDim,matxDropDim,matxEnd});
717+
auto rsc = clone<2>(rs, {matxKeepDim,1});
718+
MATX_TEST_ASSERT_COMPARE(this->pb, rsc, "c", this->thresh);
719+
}
720+
}
721+
696722
MATX_EXIT_HANDLER();
697723
}
698724

0 commit comments

Comments
 (0)