Skip to content

Commit 1f6799d

Browse files
committed
[CUBLAS] Add support for nn.dense and nn.batch_matmul
This commit includes a fix for cublas.batch_matmul when mixed precision is being used.
1 parent b2a0e1d commit 1f6799d

File tree

3 files changed

+164
-16
lines changed

3 files changed

+164
-16
lines changed

python/tvm/relay/op/contrib/cublas.py

Lines changed: 35 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -64,17 +64,23 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
6464
"""Get the cuBLAS pattern table."""
6565

6666
def matmul_pattern() -> relay.Pattern:
67-
"""Create pattern for matrix multiply."""
67+
"""Create pattern for matmul."""
6868
return is_op("nn.matmul")(wildcard(), wildcard())
6969

70-
def check_matmul(matched: relay.Call) -> bool:
70+
def batch_matmul_pattern() -> relay.Pattern:
71+
"""Create pattern for batch_matmul."""
72+
return is_op("nn.batch_matmul")(wildcard(), wildcard())
73+
74+
def dense_pattern() -> relay.Pattern:
75+
"""Create pattern for dense."""
76+
return is_op("nn.dense")(wildcard(), wildcard())
77+
78+
def check_matmul_like(matched: relay.Call) -> bool:
7179
"""Check if matmul is supported by cuBLAS."""
72-
# Units not supported
73-
if matched.attrs["units"] is not None:
74-
return False
7580
# Input data types can't be mixed
7681
if matched.args[0].checked_type.dtype != matched.args[1].checked_type.dtype:
7782
return False
83+
7884
in_dtype = matched.args[0].checked_type.dtype
7985
out_dtype = matched.checked_type.dtype
8086
# Only the following data type combinations are supported
@@ -87,18 +93,21 @@ def check_matmul(matched: relay.Call) -> bool:
8793
("int8", "float32"),
8894
]:
8995
return False
96+
9097
# If inputs are int8, input column strides must be a multiple of 4
9198
if in_dtype == "int8":
9299
if (
93-
matched.args[0].checked_type.shape[1] % 4 != 0
94-
or matched.args[1].checked_type.shape[1] % 4 != 0
100+
matched.args[0].checked_type.shape[-1] % 4 != 0
101+
or matched.args[1].checked_type.shape[-1] % 4 != 0
95102
):
96103
return False
97104

98105
return True
99106

100107
return [
101-
("cublas.matmul", matmul_pattern(), check_matmul),
108+
("cublas.matmul", matmul_pattern(), check_matmul_like),
109+
("cublas.batch_matmul", batch_matmul_pattern(), check_matmul_like),
110+
("cublas.dense", dense_pattern(), check_matmul_like),
102111
]
103112

104113

@@ -156,3 +165,21 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
156165
transb=op.attrs["transpose_b"],
157166
dtype=op.checked_type.dtype,
158167
)
168+
169+
170+
@_lower_composite("cublas.batch_matmul")
171+
def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
172+
"""Lower a batch_matmul using cuBLAS."""
173+
return cublas.batch_matmul(
174+
inputs[0],
175+
inputs[1],
176+
transa=op.attrs["transpose_a"],
177+
transb=op.attrs["transpose_b"],
178+
dtype=op.checked_type.dtype,
179+
)
180+
181+
182+
@_lower_composite("cublas.dense")
183+
def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
184+
"""Lower a dense using cuBLAS."""
185+
return cublas.matmul(inputs[0], inputs[1], False, True, dtype=op.checked_type.dtype)

src/runtime/contrib/cublas/cublas.cc

Lines changed: 8 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -277,23 +277,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
277277
ICHECK_EQ(C->ndim, 3);
278278

279279
int batch_size = BatchCount3D(C);
280-
ICHECK_EQ(ElementStride(A), 1);
281-
ICHECK_EQ(ElementStride(B), 1);
282-
ICHECK_EQ(ElementStride(C), 1);
280+
ICHECK_EQ(ElementStride3D(A), 1);
281+
ICHECK_EQ(ElementStride3D(B), 1);
282+
ICHECK_EQ(ElementStride3D(C), 1);
283283

284284
ICHECK(TypeEqual(A->dtype, B->dtype));
285285

286286
// C can never be transposed.
287-
ICHECK(!IsInPlaceTransposed(C));
287+
ICHECK(!IsInPlaceTransposed3D(C));
288288

289289
// Reversed strides indicates an in-place transpose operation.
290-
transa = IsInPlaceTransposed(A) ? !transa : transa;
291-
transb = IsInPlaceTransposed(B) ? !transb : transb;
290+
transa = IsInPlaceTransposed3D(A) ? !transa : transa;
291+
transb = IsInPlaceTransposed3D(B) ? !transb : transb;
292292

293293
ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type";
294-
ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
294+
ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0)
295295
<< "leading dimension must divide 4 for int8 gemm";
296-
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
296+
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0)
297297
<< "leading dimension must divide 4 for int8 gemm";
298298
double alpha = args.size() > 5 ? args[5] : 1.0;
299299
double beta = args.size() > 6 ? args[6] : 0.0;

tests/python/contrib/test_cublas.py

Lines changed: 121 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -256,5 +256,126 @@ def test_relay_cublas_matmul(n, m, k, in_dtype, out_dtype, transpose_a, transpos
256256
_verify_cublas_relay(matmul)
257257

258258

259+
@tvm.testing.requires_cuda
260+
@pytest.mark.parametrize(
261+
"n,m,k",
262+
[
263+
(64, 128, 32),
264+
(17, 32, 16),
265+
(24, 17, 12),
266+
(96, 4, 17),
267+
],
268+
)
269+
@pytest.mark.parametrize(
270+
"in_dtype,out_dtype",
271+
[
272+
("float32", "float32"),
273+
("float16", "float16"),
274+
("float16", "float32"),
275+
("int8", "int32"),
276+
("float64", "float64"),
277+
("int8", "float32"),
278+
],
279+
)
280+
def test_relay_cublas_dense(n, m, k, in_dtype, out_dtype):
281+
unsupported_configs = [
282+
(96, 4, 17, "int8", "float32"),
283+
(96, 4, 17, "int8", "int32"),
284+
]
285+
if (n, m, k, in_dtype, out_dtype) in unsupported_configs:
286+
pytest.skip("Unsupported parameters.")
287+
288+
data = tvm.relay.var("data", tvm.relay.TensorType((n, k), in_dtype))
289+
weight = tvm.relay.var("weight", tvm.relay.TensorType((m, k), in_dtype))
290+
dense = relay.op.nn.dense(data, weight, out_dtype=out_dtype)
291+
_verify_cublas_relay(dense)
292+
293+
294+
@tvm.testing.requires_cuda
295+
@pytest.mark.parametrize(
296+
"n,m,k,batch_a,batch_b,transpose_a,transpose_b",
297+
[
298+
(64, 128, 32, 16, 16, False, False),
299+
(17, 32, 16, 16, 1, True, False),
300+
(24, 17, 12, 17, 17, False, True),
301+
(96, 4, 17, 53, 1, True, True),
302+
],
303+
)
304+
@pytest.mark.parametrize(
305+
"in_dtype,out_dtype",
306+
[
307+
("float32", "float32"),
308+
("float16", "float16"),
309+
("float16", "float32"),
310+
("int8", "int32"),
311+
("float64", "float64"),
312+
("int8", "float32"),
313+
],
314+
)
315+
def test_relay_cublas_batch_matmul(
316+
n, m, k, batch_a, batch_b, in_dtype, out_dtype, transpose_a, transpose_b
317+
):
318+
unsupported_configs = [
319+
(17, 32, 16, 16, 1, "int8", "float32", True, False),
320+
(96, 4, 17, 53, 1, "int8", "float32", True, True),
321+
(17, 32, 16, 16, 1, "int8", "int32", True, False),
322+
(96, 4, 17, 53, 1, "int8", "int32", True, True),
323+
]
324+
if (
325+
n,
326+
m,
327+
k,
328+
batch_a,
329+
batch_b,
330+
in_dtype,
331+
out_dtype,
332+
transpose_a,
333+
transpose_b,
334+
) in unsupported_configs:
335+
pytest.skip("Unsupported parameters.")
336+
337+
a_shape = (batch_a, k, n) if transpose_a else (batch_a, n, k)
338+
b_shape = (batch_b, m, k) if transpose_b else (batch_b, k, m)
339+
a = tvm.relay.var("A", tvm.relay.TensorType(a_shape, in_dtype))
340+
b = tvm.relay.var("B", tvm.relay.TensorType(b_shape, in_dtype))
341+
batch_matmul = relay.op.nn.batch_matmul(a, b, out_dtype, transpose_a, transpose_b)
342+
_verify_cublas_relay(batch_matmul)
343+
344+
345+
@tvm.testing.requires_cuda
346+
@pytest.mark.parametrize(
347+
"n,m,k",
348+
[
349+
(64, 128, 32),
350+
(17, 32, 16),
351+
(24, 17, 12),
352+
(96, 4, 17),
353+
],
354+
)
355+
@pytest.mark.parametrize(
356+
"in_dtype,out_dtype",
357+
[
358+
("float32", "float32"),
359+
("float16", "float16"),
360+
("float16", "float32"),
361+
("int8", "int32"),
362+
("float64", "float64"),
363+
("int8", "float32"),
364+
],
365+
)
366+
def test_relay_cublas_dense(n, m, k, in_dtype, out_dtype):
367+
unsupported_configs = [
368+
(96, 4, 17, "int8", "float32"),
369+
(96, 4, 17, "int8", "int32"),
370+
]
371+
if (n, m, k, in_dtype, out_dtype) in unsupported_configs:
372+
pytest.skip("Unsupported parameters.")
373+
374+
data = tvm.relay.var("data", tvm.relay.TensorType((n, k), in_dtype))
375+
weight = tvm.relay.var("weight", tvm.relay.TensorType((m, k), in_dtype))
376+
dense = relay.op.nn.dense(data, weight, out_dtype=out_dtype)
377+
_verify_cublas_relay(dense)
378+
379+
259380
if __name__ == "__main__":
260381
pytest.main([__file__])

0 commit comments

Comments (0)