Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions include/matx/operators/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ namespace matx
float alpha_;
float beta_;
PermDims perm_;
std::array<index_t, OpA::Rank()> out_dims_;
mutable matx::tensor_t<typename OpA::scalar_type, OpA::Rank()> tmp_out_;
static constexpr int out_rank = std::max(OpA::Rank(), OpB::Rank());
std::array<index_t, out_rank> out_dims_;
mutable matx::tensor_t<typename OpA::scalar_type, out_rank> tmp_out_;

public:
using matxop = bool;
Expand All @@ -59,7 +60,7 @@ namespace matx
using matmul_xform_op = bool;

/**
 * Build a human-readable name for this operator.
 *
 * @return "matmul(<a>,<b>)", where each placeholder is whatever
 *         get_type_str yields for the corresponding operand.
 */
__MATX_INLINE__ std::string str() const {
  // Name both operands; reporting only a_ would drop the B input from the string.
  return "matmul(" + get_type_str(a_) + "," + get_type_str(b_) + ")";
}

__MATX_INLINE__ MatMulOp(OpA a, OpB b, float alpha, float beta, PermDims perm) :
Expand All @@ -73,17 +74,17 @@ namespace matx
out_dims_[r] = b_.Size(perm_[r]);
}
else {
out_dims_[r] = a_.Size(r);
out_dims_[r] = OpA::Rank() > OpB::Rank() ? a_.Size(r) : b_.Size(r);
}
}
}
else {
for (int r = 0; r < Rank() - 2; r++) {
out_dims_[r] = a_.Size(r);
out_dims_[r] = OpA::Rank() > OpB::Rank() ? a_.Size(r) : b_.Size(r);
}

out_dims_[OpA::Rank() - 2] = a_.Size(OpA::Rank() - 2);
out_dims_[OpB::Rank() - 1] = b_.Size(OpB::Rank() - 1);
out_dims_[Rank() - 2] = a_.Size(OpA::Rank() - 2);
out_dims_[Rank() - 1] = b_.Size(OpB::Rank() - 1);
}
}

Expand All @@ -96,7 +97,7 @@ namespace matx

/**
 * Rank of this operator's output.
 *
 * @return out_rank, i.e. the larger of OpA::Rank() and OpB::Rank(). This is the
 *         same constant that sizes out_dims_ and fixes the rank of tmp_out_, so
 *         the reported rank always agrees with the stored output shape.
 */
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
  return out_rank;
}
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
Expand All @@ -123,7 +124,7 @@ namespace matx

if constexpr (is_matx_op<OpB>()) {
b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

if constexpr (is_device_executor_v<Executor>) {
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
Expand Down