268 changes: 175 additions & 93 deletions include/matx/transforms/inverse.h
@@ -74,6 +74,10 @@ class matxInversePlan_t {
constexpr static int RANK = TensorTypeA::Rank();
static_assert(RANK == TensorTypeAInv::Rank(), "Input and output tensor ranks must match");
using T1 = typename TensorTypeAInv::value_type;
// Linear systems with size less than or equal to this threshold use the cublas*matinvBatched
// functions. These use one fused kernel rather than two separate kernels and do not
// overwrite the input, so in some cases we do not require a temporary work buffer for the input.
static constexpr int BATCHED_SINGLE_CALL_INV_THRESHOLD = 32;

public:
/**
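As a rough usage sketch of what this threshold means at the API level (not part of this diff; it assumes MatX's standard make_tensor()/inv()/run() interface, and the shapes and stream are illustrative only):

    cudaStream_t stream = 0;                              // default stream, for illustration only
    auto A     = matx::make_tensor<float>({8, 16, 16});   // batch of 8 systems with N = 16 <= 32
    auto A_inv = matx::make_tensor<float>({8, 16, 16});
    (A_inv = matx::inv(A)).run(stream);                   // small N: eligible for the fused matinvBatched path
    // With N > 32 (e.g. {8, 64, 64}) the same call takes the two-step getrf/getri path instead.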
@@ -95,6 +99,8 @@ class matxInversePlan_t {
* Input tensor view
* @param a_inv
* Inverse of A (if it exists)
* @param stream
* CUDA stream on which the operation runs
*
*/
matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
@@ -118,37 +124,53 @@ class matxInversePlan_t {

params = GetInverseParams(a_inv, a, stream);

const bool use_input_workbuf = UseInputWorkBuffer(a);
// The cuBLAS getr*Batched LU decomposition functions overwrite the input, so
// we use a temporary buffer to store the inputs.
make_tensor(a_workbuf, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
if (use_input_workbuf) {
make_tensor(a_workbuf, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
}

if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
// cuBLAS batched APIs require an array of pointers, one per matrix. Construct that
// list here from the batch dimensions
std::vector<const T1 *> in_pointers;
std::vector<T1 *> out_pointers;
if constexpr (RANK == 2) {
in_pointers.push_back(&a_workbuf(0, 0));
if (use_input_workbuf) {
in_pointers.push_back(&a_workbuf(0, 0));
} else {
if constexpr (is_tensor_view_v<TensorTypeA>) {
// We know this is a tensor view because we are not using a_workbuf
in_pointers.push_back(&a(0, 0));
}
}
out_pointers.push_back(&a_inv(0, 0));
}
else {
using ShapeTypeA = typename decltype(a_workbuf)::desc_type::shape_type;
using ShapeTypeAInv = typename TensorTypeAInv::desc_type::shape_type;
int batch_offset = 2;
cuda::std::array<ShapeTypeA, TensorTypeA::Rank()> a_idx{0};
cuda::std::array<ShapeTypeAInv, TensorTypeAInv::Rank()> a_inv_idx{0};
cuda::std::array<index_t, TensorTypeA::Rank()> a_idx{0};
cuda::std::array<index_t, TensorTypeAInv::Rank()> a_inv_idx{0};
auto a_shape = a.Shape();
// Get total number of batches
size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<ShapeTypeA>());
size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<index_t>());
for (size_t iter = 0; iter < total_iter; iter++) {
auto ip = cuda::std::apply([&a_workbuf = a_workbuf](auto... param) { return a_workbuf.GetPointer(param...); }, a_idx);
in_pointers.push_back(ip);
// Update all but the last 2 indices
UpdateIndices<decltype(a_workbuf), ShapeTypeA, TensorTypeA::Rank()>(a_workbuf, a_idx, batch_offset);
if (use_input_workbuf) {
auto ip = cuda::std::apply([&a_workbuf = a_workbuf](auto... param) { return a_workbuf.GetPointer(param...); }, a_idx);
in_pointers.push_back(ip);
UpdateIndices<decltype(a_workbuf), index_t, TensorTypeA::Rank()>(a_workbuf, a_idx, batch_offset);
} else {
if constexpr (is_tensor_view_v<TensorTypeA>) {
// We know this is a tensor view because we are not using a_workbuf
auto ip = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, a_idx);
in_pointers.push_back(ip);
UpdateIndices<TensorTypeA, index_t, TensorTypeA::Rank()>(a, a_idx, batch_offset);
}
}

auto op = cuda::std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, a_inv_idx);
out_pointers.push_back(op);
UpdateIndices<TensorTypeAInv, ShapeTypeAInv, TensorTypeAInv::Rank()>(a_inv, a_inv_idx, batch_offset);
UpdateIndices<TensorTypeAInv, index_t, TensorTypeAInv::Rank()>(a_inv, a_inv_idx, batch_offset);
}
}

@@ -157,9 +179,12 @@ class matxInversePlan_t {
MATX_ASYNC_DEVICE_MEMORY, stream);
matxAlloc((void **)&d_A_inv_array, out_pointers.size() * sizeof(T1 *),
MATX_ASYNC_DEVICE_MEMORY, stream);
matxAlloc((void **)&d_pivot,
a.Size(RANK - 1) * in_pointers.size() * sizeof(*d_info),
MATX_ASYNC_DEVICE_MEMORY, stream);
if (a.Size(RANK-1) > BATCHED_SINGLE_CALL_INV_THRESHOLD) {
// The single-call matinv functions do not output pivots, so the pivot buffer is only needed on the getrf/getri path
matxAlloc((void **)&d_pivot,
a.Size(RANK - 1) * in_pointers.size() * sizeof(*d_info),
MATX_ASYNC_DEVICE_MEMORY, stream);
}
matxAlloc((void **)&d_info, in_pointers.size() * sizeof(*d_info),
MATX_ASYNC_DEVICE_MEMORY, stream);
matxAlloc((void **)&h_info, in_pointers.size() * sizeof(*h_info),
@@ -174,6 +199,15 @@ class matxInversePlan_t {
}
}

static inline bool UseInputWorkBuffer(const TensorTypeA &a)
{
if constexpr (!is_tensor_view_v<TensorTypeA>) {
return true;
} else {
return a.Size(RANK-1) > BATCHED_SINGLE_CALL_INV_THRESHOLD || !a.IsContiguous();
}
}
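As a hedged illustration of how this predicate resolves for a few inputs (make_tensor()/inv()/run() are the standard MatX API; the shapes and the scaled-operator expression are examples only, not taken from this diff):

    cudaStream_t stream = 0;
    auto small = matx::make_tensor<float>({4, 16, 16});
    auto large = matx::make_tensor<float>({4, 64, 64});
    auto inv_s = matx::make_tensor<float>({4, 16, 16});
    auto inv_l = matx::make_tensor<float>({4, 64, 64});

    (inv_s = matx::inv(small)).run(stream);          // contiguous view, N = 16 <= 32  -> no a_workbuf copy
    (inv_l = matx::inv(large)).run(stream);          // N = 64 > 32, getrf overwrites  -> copy into a_workbuf
    (inv_s = matx::inv(small * 2.0f)).run(stream);   // operator input, not a view     -> materialize into a_workbuf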

static InverseParams_t GetInverseParams(TensorTypeAInv &a_inv,
const TensorTypeA &a,
cudaStream_t stream)
@@ -211,11 +245,11 @@ class matxInversePlan_t {
*/
~matxInversePlan_t()
{
matxFree(d_A_array, cudaStreamDefault);
matxFree(d_A_inv_array, cudaStreamDefault);
matxFree(d_pivot, cudaStreamDefault);
matxFree(d_info, cudaStreamDefault);
matxFree(h_info);
if (d_A_array) { matxFree(d_A_array, cudaStreamDefault); d_A_array = nullptr; }
if (d_A_inv_array) { matxFree(d_A_inv_array, cudaStreamDefault); d_A_inv_array = nullptr; }
if (d_pivot) { matxFree(d_pivot, cudaStreamDefault); d_pivot = nullptr; }
if (d_info) { matxFree(d_info, cudaStreamDefault); d_info = nullptr; }
if (h_info) { matxFree(h_info); h_info = nullptr; }

cublasDestroy(handle);
}
@@ -230,6 +264,10 @@ class matxInversePlan_t {
*
* @tparam T1
* Type of matrix A
* @param a_inv
* Output tensor or operator into which the inverse of A is written, if it exists
* @param a
* Input tensor or operator for which the inverse will be computed
* @param stream
* CUDA stream
*
@@ -240,82 +278,126 @@

cublasSetStream(handle, stream);

(a_workbuf = a).run(stream);
if (UseInputWorkBuffer(a)) {
(a_workbuf = a).run(stream);
}

if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
if constexpr (std::is_same_v<T1, float>) {
ret =
cublasSgetrfBatched(handle, static_cast<int>(params.n), d_A_array, static_cast<int>(params.n), d_pivot,
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, double>) {
ret =
cublasDgetrfBatched(handle, static_cast<int>(params.n), d_A_array, static_cast<int>(params.n), d_pivot,
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<float>>) {
ret =
cublasCgetrfBatched(handle, static_cast<int>(params.n),
reinterpret_cast<cuComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot, d_info,
static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
ret = cublasZgetrfBatched(
handle, static_cast<int>(params.n),
reinterpret_cast<cuDoubleComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot, d_info,
static_cast<int>(params.batch_size));
}
if (a.Size(TensorTypeA::Rank()-1) > BATCHED_SINGLE_CALL_INV_THRESHOLD) {
if constexpr (std::is_same_v<T1, float>) {
ret =
cublasSgetrfBatched(handle, static_cast<int>(params.n), d_A_array, static_cast<int>(params.n), d_pivot,
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, double>) {
ret =
cublasDgetrfBatched(handle, static_cast<int>(params.n), d_A_array, static_cast<int>(params.n), d_pivot,
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<float>>) {
ret =
cublasCgetrfBatched(handle, static_cast<int>(params.n),
reinterpret_cast<cuComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot, d_info,
static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
ret = cublasZgetrfBatched(
handle, static_cast<int>(params.n),
reinterpret_cast<cuDoubleComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot, d_info,
static_cast<int>(params.batch_size));
}

MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxLUError);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxLUError);

cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
for (size_t i = 0; i < params.batch_size; i++) {
if (h_info[i] != 0) {
MATX_THROW(matxLUError, "inverse failed");
cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
for (size_t i = 0; i < params.batch_size; i++) {
if (h_info[i] != 0) {
MATX_THROW(matxLUError, "inverse failed");
}
}
}

if constexpr (std::is_same_v<T1, float>) {
ret = cublasSgetriBatched(handle, static_cast<int>(params.n), d_A_array,
static_cast<int>(params.n), d_pivot,
d_A_inv_array, static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, double>) {
ret = cublasDgetriBatched(handle, static_cast<int>(params.n), d_A_array,
static_cast<int>(params.n), d_pivot,
d_A_inv_array, static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<float>>) {
ret = cublasCgetriBatched(
handle, static_cast<int>(params.n),
reinterpret_cast<cuComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot,
reinterpret_cast<cuComplex *const *>(d_A_inv_array),
static_cast<int>(params.n), d_info,
static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
ret = cublasZgetriBatched(
handle, static_cast<int>(params.n),
reinterpret_cast<cuDoubleComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot,
reinterpret_cast<cuDoubleComplex *const *>(d_A_inv_array),
static_cast<int>(params.n), d_info,
static_cast<int>(params.batch_size));
}
if constexpr (std::is_same_v<T1, float>) {
ret = cublasSgetriBatched(handle, static_cast<int>(params.n), d_A_array,
static_cast<int>(params.n), d_pivot,
d_A_inv_array, static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, double>) {
ret = cublasDgetriBatched(handle, static_cast<int>(params.n), d_A_array,
static_cast<int>(params.n), d_pivot,
d_A_inv_array, static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<float>>) {
ret = cublasCgetriBatched(
handle, static_cast<int>(params.n),
reinterpret_cast<cuComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot,
reinterpret_cast<cuComplex *const *>(d_A_inv_array),
static_cast<int>(params.n), d_info,
static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
ret = cublasZgetriBatched(
handle, static_cast<int>(params.n),
reinterpret_cast<cuDoubleComplex *const *>(d_A_array),
static_cast<int>(params.n), d_pivot,
reinterpret_cast<cuDoubleComplex *const *>(d_A_inv_array),
static_cast<int>(params.n), d_info,
static_cast<int>(params.batch_size));
}

MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);

cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
for (size_t i = 0; i < params.batch_size; i++) {
if (h_info[i] != 0) {
MATX_THROW(matxLUError, "inverse failed");
cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
for (size_t i = 0; i < params.batch_size; i++) {
if (h_info[i] != 0) {
MATX_THROW(matxLUError, "inverse failed");
}
}
} else {
// For linear systems of size <= BATCHED_SINGLE_CALL_INV_THRESHOLD, we can use the more
// efficient single-call inverse functions.
if constexpr (std::is_same_v<T1, float>) {
ret = cublasSmatinvBatched(handle, static_cast<int>(params.n), d_A_array,
static_cast<int>(params.n),
d_A_inv_array, static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, double>) {
ret = cublasDmatinvBatched(handle, static_cast<int>(params.n), d_A_array,
static_cast<int>(params.n),
d_A_inv_array, static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<float>>) {
ret = cublasCmatinvBatched(handle, static_cast<int>(params.n),
reinterpret_cast<cuComplex *const *>(d_A_array),
static_cast<int>(params.n),
reinterpret_cast<cuComplex *const *>(d_A_inv_array),
static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
ret = cublasZmatinvBatched(handle, static_cast<int>(params.n),
reinterpret_cast<cuDoubleComplex *const *>(d_A_array),
static_cast<int>(params.n),
reinterpret_cast<cuDoubleComplex *const *>(d_A_inv_array),
static_cast<int>(params.n),
d_info, static_cast<int>(params.batch_size));
}
MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);

cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
for (size_t i = 0; i < params.batch_size; i++) {
if (h_info[i] != 0) {
MATX_THROW(matxLUError, "inverse failed");
}
}
}
}
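Both paths above report only a generic "inverse failed". Per my reading of the cuBLAS batched-API documentation (an assumption, not something this diff states), a positive h_info[i] is the index k of the zero pivot U(k,k) that made matrix i singular, so a more descriptive check could look like this sketch (h_info, params, MATX_THROW, and matxLUError are the surrounding code's own names):

    for (size_t i = 0; i < params.batch_size; i++) {
      if (h_info[i] > 0) {
        // Matrix i in the batch is singular: the factorization found U(k,k) == 0 at k = h_info[i].
        MATX_THROW(matxLUError, "matrix in batch is singular; no inverse exists");
      }
    }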
@@ -328,11 +410,11 @@ class matxInversePlan_t {
InverseParams_t params;
cublasHandle_t handle;
matx::tensor_t<typename TensorTypeA::value_type, TensorTypeA::Rank()> a_workbuf;
int *d_pivot;
int *d_info;
int *h_info;
T1 **d_A_array;
T1 **d_A_inv_array;
int *d_pivot { nullptr };
int *d_info { nullptr };
int *h_info { nullptr };
T1 **d_A_array { nullptr };
T1 **d_A_inv_array { nullptr };
};
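A compact restatement of the ownership pattern the destructor and these declarations adopt (a plain CUDA sketch, not MatX code): members default to nullptr, allocation is conditional, and teardown frees only what was actually allocated, which is why d_pivot can legitimately stay null on the matinvBatched path.

    struct cond_buffer {
      int *d_pivot { nullptr };                     // only allocated when the getrf/getri path is used
      ~cond_buffer() {
        if (d_pivot) { cudaFree(d_pivot); d_pivot = nullptr; }
      }
    };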

/**