28 changes: 28 additions & 0 deletions python/tvm/contrib/cublas.py
@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")

def batch_matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS

Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs

Returns
-------
C : Tensor
The result tensor.
"""
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return _api.extern(
(b, n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.batch_matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
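
A minimal usage sketch of the new wrapper (hypothetical shapes; assumes a TVM build with CUDA and cuBLAS enabled):

import numpy as np
import tvm
from tvm.contrib import cublas

# Four independent (32 x 16) @ (16 x 8) products in one extern op.
A = tvm.placeholder((4, 32, 16), name='A')
B = tvm.placeholder((4, 16, 8), name='B')
C = cublas.batch_matmul(A, B)  # inferred output shape: (4, 32, 8)

s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "cuda")

ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(4, 32, 16)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(4, 16, 8)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((4, 32, 8), dtype=C.dtype), ctx)
f(a, b, c)  # c now equals np.matmul(a, b) up to float tolerance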
60 changes: 60 additions & 0 deletions src/contrib/cublas/cublas.cc
@@ -81,6 +81,50 @@ struct CublasDgemmOp {
}
};

struct CublasSgemmBatchOp {
typedef float TDatatype;
cublasHandle_t handle;
explicit CublasSgemmBatchOp(cublasHandle_t hdl)
: handle(hdl)
{}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) {
CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha,
A, lda, a_stride,
B, ldb, b_stride,
&beta,
C, ldc, c_stride,
batch_size));
}
};

struct CublasDgemmBatchOp {
typedef double TDatatype;
cublasHandle_t handle;
explicit CublasDgemmBatchOp(cublasHandle_t hdl)
: handle(hdl)
{}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha,
A, lda, a_stride,
B, ldb, b_stride,
&beta,
C, ldc, c_stride,
batch_size));
}
};
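
cuBLAS assumes column-major storage while the buffers here are row-major. The standard trick, applied inside the CallGemm/CallBatchGemm helpers (defined outside this diff), is to hand the operands to cuBLAS swapped: a column-major routine reading row-major memory sees the transposed matrices, and (B^T)(A^T) = (AB)^T. A small NumPy check of that identity, for illustration only:

import numpy as np

a = np.random.rand(3, 4)
b = np.random.rand(4, 5)

# A column-major GEMM fed row-major buffers effectively computes
# B^T @ A^T; its result, reinterpreted as row-major, is exactly A @ B.
np.testing.assert_allclose((b.T @ a.T).T, a @ b, rtol=1e-12)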

// Matrix multiplication for row-major layout
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
else
CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
});

TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];

CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();

if (TypeMatch(A->dtype, kDLFloat, 32))
CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
else
CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
});

} // namespace contrib
} // namespace tvm
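
A hedged sketch of exercising the newly registered packed function directly; the argument order (A, B, C, transa, transb) mirrors the call_packed site in the Python wrapper, and a cuBLAS-enabled GPU build is assumed:

import numpy as np
import tvm

f = tvm.get_global_func("tvm.contrib.cublas.batch_matmul")

ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(4, 32, 16)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(4, 16, 8)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((4, 32, 8), dtype="float32"), ctx)

# One strided-batched SGEMM: c[i] = a[i] @ b[i] for every batch index i.
f(a, b, c, False, False)
np.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()),
                           rtol=1e-5)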
28 changes: 28 additions & 0 deletions tests/python/contrib/test_cublas.py
@@ -44,6 +44,34 @@ def verify(target="cuda"):
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()

def test_batch_matmul():
j = 16
n = 1024
l = 128
m = 235
A = tvm.placeholder((j, n, l), name='A')
B = tvm.placeholder((j, l, m), name='B')
C = cublas.batch_matmul(A, B)
s = tvm.create_schedule(C.op)

def verify(target="cuda"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
        if not tvm.get_global_func("tvm.contrib.cublas.batch_matmul", True):
print("skip because extern function is not available")
return
ctx = tvm.gpu(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(
c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()


if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
32 changes: 32 additions & 0 deletions topi/include/topi/contrib/cublas.h
@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
}, "C", "", {})[0];
}

/*!
 * \brief Create an extern op that computes batch matrix
 * multiplication of lhs and rhs with cuBLAS
*
* \param lhs The left matrix operand
* \param rhs The right matrix operand
* \param transa Whether to transpose lhs
* \param transb Whether to transpose rhs
*
* \return The output tensor
*/
inline Tensor cublas_batch_matmul(const Tensor& lhs,
const Tensor& rhs,
bool transa,
bool transb) {
auto b = lhs->shape[0];
auto n = transa ? lhs->shape[2] : lhs->shape[1];
auto m = transb ? rhs->shape[1] : rhs->shape[2];

return make_extern(
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
Expr("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
transa,
transb });
}, "C", "", {})[0];
}

} // namespace contrib
} // namespace topi

29 changes: 28 additions & 1 deletion topi/python/topi/cuda/batch_matmul.py
@@ -18,10 +18,33 @@
"""cuda batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm

from tvm.contrib import cublas
from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

@batch_matmul.register(["cuda", "gpu"])
def batch_matmul_cuda(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]

y : tvm.Tensor
3-D with shape [batch, N, K]

Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
target = tvm.target.current_target()
if target.target_name == "cuda" and "cublas" in target.libs:
return cublas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y)
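
Note the transb=True above: topi's batch_matmul convention stores y as [batch, N, K], so the cuBLAS path requests y transposed. A sketch of how this dispatch is selected (hypothetical shapes; the -libs=cublas target flag is what places "cublas" in target.libs):

import tvm
import topi

x = tvm.placeholder((4, 32, 16), name='x')  # [batch, M, K]
y = tvm.placeholder((4, 8, 16), name='y')   # [batch, N, K], K innermost

with tvm.target.create("cuda -libs=cublas"):
    out = topi.nn.batch_matmul(x, y)             # -> cublas.batch_matmul(x, y, False, True)
    s = topi.generic.schedule_batch_matmul(out)  # -> generic.schedule_extern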

@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
s: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target()
if target.target_name == "cuda" and "cublas" in target.libs:
return generic.schedule_extern(outs)

outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
