@@ -60,9 +60,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
6060 const char *,
6161 const char *,
6262 char *,
63- #if !defined(USE_ONEMATH_CUBLAS)
6463 const bool ,
65- #endif // !USE_ONEMATH_CUBLAS
6664 const std::vector<sycl::event> &);
6765
6866static gemm_batch_impl_fn_ptr_t
@@ -85,9 +83,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
8583 const char *matrixA,
8684 const char *matrixB,
8785 char *resultC,
88- #if !defined(USE_ONEMATH_CUBLAS)
8986 const bool is_row_major,
90- #endif // !USE_ONEMATH_CUBLAS
9187 const std::vector<sycl::event> &depends)
9288{
9389 type_utils::validate_type_for_device<Tab>(exec_q);
@@ -112,11 +108,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
112108 Tc *c, const std::int64_t ldc, const std::int64_t stridec,
113109 const std::int64_t batch_size,
114110 const std::vector<sycl::event> &deps) -> sycl::event {
115- #if defined(USE_ONEMATH_CUBLAS)
116- return mkl_blas::column_major::gemm_batch (
117- q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
118- strideb, beta, c, ldc, stridec, batch_size, deps);
119- #else
120111 if (is_row_major) {
121112 return mkl_blas::row_major::gemm_batch (
122113 q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
@@ -127,7 +118,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q,
127118 q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
128119 strideb, beta, c, ldc, stridec, batch_size, deps);
129120 }
130- #endif // USE_ONEMATH_CUBLAS
131121 };
132122 gemm_batch_event = gemm_batch_func (
133123 exec_q,
@@ -317,7 +307,7 @@ std::tuple<sycl::event, sycl::event, bool>
317307
318308// cuBLAS supports only column-major storage
319309#if defined(USE_ONEMATH_CUBLAS)
320- const bool is_row_major = false ;
310+ constexpr bool is_row_major = false ;
321311
322312 transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
323313 : oneapi::mkl::transpose::N;
@@ -396,17 +386,10 @@ std::tuple<sycl::event, sycl::event, bool>
396386 const char *b_typeless_ptr = matrixB.get_data ();
397387 char *r_typeless_ptr = resultC.get_data ();
398388
399- #if defined(USE_ONEMATH_CUBLAS)
400- sycl::event gemm_batch_ev =
401- gemm_batch_fn (exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
402- strideb, stridec, transA, transB, a_typeless_ptr,
403- b_typeless_ptr, r_typeless_ptr, depends);
404- #else
405389 sycl::event gemm_batch_ev =
406390 gemm_batch_fn (exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
407391 strideb, stridec, transA, transB, a_typeless_ptr,
408392 b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
409- #endif // USE_ONEMATH_CUBLAS
410393
411394 sycl::event args_ev = dpctl::utils::keep_args_alive (
412395 exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
0 commit comments