Skip to content

Commit 0fb597a

Browse files
committed
NVPL FFT support
Switched all unit tests to use tuple with executor and type
1 parent c3420a0 commit 0fb597a

90 files changed

Lines changed: 2545 additions & 1580 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ option(MATX_EN_VISUALIZATION "Enable visualization support" OFF)
2121
option(MATX_EN_CUTLASS OFF)
2222
option(MATX_EN_CUTENSOR OFF)
2323
option(MATX_EN_FILEIO OFF)
24+
option(MATX_EN_NVPL OFF, "Enable NVIDIA Performance Libraries for optimized ARM CPU support")
2425
option(MATX_DISABLE_CUB_CACHE "Disable caching for CUB allocations" ON)
2526

2627
set(MATX_EN_PYBIND11 OFF CACHE BOOL "Enable pybind11 support")
@@ -152,6 +153,15 @@ else()
152153
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)
153154
endif()
154155

156+
if (MATX_EN_NVPL)
157+
message(STATUS "Enabling NVPL library support")
158+
# find_package is currently broken in NVPL. Use proper targets once working
159+
#find_package(nvpl REQUIRED COMPONENTS fft)
160+
#target_link_libraries(matx INTERFACE nvpl::fftw)
161+
target_link_libraries(matx INTERFACE nvpl_fftw)
162+
target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
163+
endif()
164+
155165
if (MATX_DISABLE_CUB_CACHE)
156166
target_compile_definitions(matx INTERFACE MATX_DISABLE_CUB_CACHE=1)
157167
endif()
@@ -291,4 +301,3 @@ if (MATX_BUILD_TESTS)
291301
include(cmake/GetGTest.cmake)
292302
add_subdirectory(test)
293303
endif()
294-

docs_input/api/dft/fft/fft2d.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ Perform a 2D FFT
99
These functions are currently not supported with host-based executors (CPU)
1010

1111

12-
.. doxygenfunction:: fft2(OpA &&a)
13-
.. doxygenfunction:: fft2(OpA &&a, const int32_t (&axis)[2])
12+
.. doxygenfunction:: fft2(OpA &&a, FFTNorm norm = FFTNorm::BACKWARD)
13+
.. doxygenfunction:: fft2(OpA &&a, const int32_t (&axis)[2], FFTNorm norm = FFTNorm::BACKWARD)
1414

1515
Examples
1616
~~~~~~~~

docs_input/api/dft/fft/ifft2.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ Perform a 2D inverse FFT
99
These functions are currently not supported with host-based executors (CPU)
1010

1111

12-
.. doxygenfunction:: ifft2(OpA &&a)
13-
.. doxygenfunction:: ifft2(OpA &&a, const int32_t (&axis)[2])
12+
.. doxygenfunction:: ifft2(OpA &&a, FFTNorm norm = FFTNorm::BACKWARD)
13+
.. doxygenfunction:: ifft2(OpA &&a, const int32_t (&axis)[2], FFTNorm norm = FFTNorm::BACKWARD)
1414

1515
Examples
1616
~~~~~~~~

docs_input/api/manipulation/basic/copy.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ since it cannot be chained with other expressions.
1414
Examples
1515
~~~~~~~~
1616

17-
.. literalinclude:: ../../../../include/matx/transforms/fft.h
17+
.. literalinclude:: ../../../../include/matx/transforms/fft/fft_common.h
1818
:language: cpp
1919
:start-after: example-begin copy-test-1
2020
:end-before: example-end copy-test-1

docs_input/build.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ Optional Third-party Dependencies
4343
- `cutensor <https://developer.nvidia.com/cutensor>`_ 1.7.0.1+ (Required when using `einsum`)
4444
- `cutensornet <https://docs.nvidia.com/cuda/cuquantum/cutensornet>`_ 23.03.0.20+ (Required when using `einsum`)
4545

46+
Host (CPU) Support
47+
------------------
48+
Host support is provided both by the C++ standard library and NVIDIA's NVPL_ library. Host support is
49+
considered experimental and is still a work in progress. Currently all reduction functions are supported,
50+
but only FFT transforms are supported. All host support is limited to a single thread in this release.
51+
52+
To enable NVPL support use the CMake option `-DMATX_EN_NVPL=ON`.
53+
54+
.. _NVPL: https://developer.nvidia.com/nvpl
55+
4656
Build Options
4757
=============
4858
MatX provides 5 primary options for builds, and each can be configured independently:

include/matx/core/error.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ namespace matx
6565
matxLUError,
6666
matxInverseError,
6767
matxSolverError,
68-
matxcuTensorError
68+
matxcuTensorError,
69+
matxInvalidExecutor
6970
};
7071

7172
static constexpr const char *matxErrorString(matxError_t e)

include/matx/core/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
14861486

14871487
if constexpr (N > 0) {
14881488
if (end != matxDropDim) {
1489+
MATX_ASSERT_STR(end != matxKeepDim, matxInvalidParameter, "matxKeepDim only valid for clone(), not slice()");
14891490
if (end == matxEnd) {
14901491
n[d] = this->Size(i) - first;
14911492
}

include/matx/core/type_utils.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ constexpr bool is_executor_t()
272272

273273

274274
namespace detail {
275-
template<typename T> struct is_device_executor : std::false_type {};
276-
template<> struct is_device_executor<matx::cudaExecutor> : std::true_type {};
275+
template<typename T> struct is_cuda_executor : std::false_type {};
276+
template<> struct is_cuda_executor<matx::cudaExecutor> : std::true_type {};
277277
}
278278

279279
/**
@@ -282,11 +282,11 @@ template<> struct is_device_executor<matx::cudaExecutor> : std::true_type {};
282282
* @tparam T Type to test
283283
*/
284284
template <typename T>
285-
inline constexpr bool is_device_executor_v = detail::is_device_executor<typename remove_cvref<T>::type>::value;
285+
inline constexpr bool is_cuda_executor_v = detail::is_cuda_executor<typename remove_cvref<T>::type>::value;
286286

287287
namespace detail {
288-
template<typename T> struct is_single_thread_host_executor : std::false_type {};
289-
template<> struct is_single_thread_host_executor<matx::HostExecutor> : std::true_type {};
288+
template<typename T> struct is_host_executor : std::false_type {};
289+
template<> struct is_host_executor<matx::HostExecutor> : std::true_type {};
290290
}
291291

292292
/**
@@ -295,7 +295,7 @@ template<> struct is_single_thread_host_executor<matx::HostExecutor> : std::true
295295
* @tparam T Type to test
296296
*/
297297
template <typename T>
298-
inline constexpr bool is_single_thread_host_executor_v = detail::is_single_thread_host_executor<remove_cvref_t<T>>::value;
298+
inline constexpr bool is_host_executor_v = detail::is_host_executor<remove_cvref_t<T>>::value;
299299

300300

301301
namespace detail {

include/matx/executors/device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ namespace matx
6666
/*
6767
* @breif Returns stream associated with executor
6868
*/
69-
auto getStream() { return stream_; }
69+
auto getStream() const { return stream_; }
7070

7171
/**
7272
* Execute an operator on a device

include/matx/executors/executors.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@
3232

3333
#pragma once
3434

35+
#include "matx/executors/support.h"
3536
#include "matx/executors/device.h"
3637
#include "matx/executors/host.h"

0 commit comments

Comments
 (0)