Skip to content
  •  
  •  
  •  
15 changes: 8 additions & 7 deletions bench/00_operators/reduction.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ void softmax(nvbench::state &state, nvbench::type_list<ValueType>)
t4.PrefetchDevice(0);
t4out.PrefetchDevice(0);

softmax(t4out, t4, {3});
(t4out = softmax(t4, {3})).run();

state.exec(
[&t4, &t4out](nvbench::launch &launch) {
matx::softmax(t4out, t4, (cudaStream_t)launch.get_stream());
(t4out = softmax(t4)).run((cudaStream_t)launch.get_stream());
});
}
NVBENCH_BENCH_TYPES(softmax, NVBENCH_TYPE_AXES(softmax_types));
Expand All @@ -40,11 +40,11 @@ void reduce_0d_matx(nvbench::state &state, nvbench::type_list<ValueType>)
auto xv2 = make_tensor<ValueType>();
xv.PrefetchDevice(0);

matx::sum(xv2, xv);
(xv2 = matx::sum(xv)).run();

state.exec(
[&xv, &xv2](nvbench::launch &launch) {
matx::sum(xv2, xv, (cudaStream_t)launch.get_stream());
(xv2 = matx::sum(xv)).run((cudaStream_t)launch.get_stream());
});

}
Expand Down Expand Up @@ -74,11 +74,11 @@ void reduce_0d_cub(nvbench::state &state, nvbench::type_list<ValueType>)
auto xv2 = make_tensor<ValueType>();
xv.PrefetchDevice(0);

sum(xv2, xv, 0);
(xv2 = matx::sum(xv)).run();

state.exec(
[&xv, &xv2](nvbench::launch &launch) {
sum(xv2, xv, (cudaStream_t)launch.get_stream());
(xv2 = matx::sum(xv)).run((cudaStream_t)launch.get_stream());
});

}
Expand Down Expand Up @@ -141,7 +141,8 @@ void reduce_4d(
(t4 = random<float>(t4.Shape(), UNIFORM)).run();
cudaDeviceSynchronize();

state.exec([&t4, &t1](nvbench::launch &launch) { matx::sum(t1, t4, (cudaStream_t)launch.get_stream()); });
state.exec([&t4, &t1](nvbench::launch &launch) {
(t1 = matx::sum(t4, {1, 2, 3})).run((cudaStream_t)launch.get_stream()); });

}

Expand Down
8 changes: 4 additions & 4 deletions bench/00_transform/conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void conv1d_4d_batch(nvbench::state &state,
cudaDeviceSynchronize();
MATX_NVTX_START_RANGE( "Exec", matx_nvxtLogLevels::MATX_NVTX_LOG_ALL, 1 )
state.exec(
[&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
[&out, &at, &bt](nvbench::launch &launch) { (out = conv1d(at, bt, MATX_C_MODE_FULL)).run(cudaExecutor(launch.get_stream())); });
MATX_NVTX_END_RANGE( 1 )

}
Expand All @@ -48,7 +48,7 @@ void conv1d_2d_batch(nvbench::state &state,
cudaDeviceSynchronize();

state.exec(
[&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
[&out, &at, &bt](nvbench::launch &launch) { (out = conv1d(at, bt, MATX_C_MODE_FULL)).run(cudaExecutor(launch.get_stream())); });
}
NVBENCH_BENCH_TYPES(conv1d_2d_batch, NVBENCH_TYPE_AXES(conv_types));

Expand All @@ -67,7 +67,7 @@ void conv1d_large(nvbench::state &state,
cudaDeviceSynchronize();

state.exec(
[&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
[&out, &at, &bt](nvbench::launch &launch) { (out = conv1d(at, bt, MATX_C_MODE_FULL)).run(cudaExecutor(launch.get_stream())); });
}
NVBENCH_BENCH_TYPES(conv1d_large, NVBENCH_TYPE_AXES(conv_types));

Expand All @@ -88,7 +88,7 @@ void conv2d_batch(nvbench::state &state,
cudaDeviceSynchronize();

state.exec(
[&out, &at, &bt](nvbench::launch &launch) { conv2d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
[&out, &at, &bt](nvbench::launch &launch) { (out = conv2d(at, bt, MATX_C_MODE_FULL)).run(cudaExecutor(launch.get_stream())); });

auto seconds = state.get_summary("Batch GPU").get_float64("value");
auto &flops = state.add_summary("TFLOPS");
Expand Down
8 changes: 6 additions & 2 deletions bench/00_transform/cub.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ void sort1d(

(randomData = random<float>(sortedData.Shape(), NORMAL)).run();

state.exec( [&sortedData, &randomData](nvbench::launch &launch) { matx::sort(sortedData, randomData, SORT_DIR_ASC, (cudaStream_t)launch.get_stream()); });
state.exec( [&sortedData, &randomData](nvbench::launch &launch) {
(sortedData = matx::sort(randomData, SORT_DIR_ASC)).run(cudaExecutor(launch.get_stream()));
});

}

Expand Down Expand Up @@ -69,7 +71,9 @@ void sort2d(

(randomData = random<float>(sortedData.Shape(), NORMAL)).run();

state.exec( [&sortedData, &randomData](nvbench::launch &launch) { matx::sort(sortedData, randomData, SORT_DIR_ASC, (cudaStream_t)launch.get_stream()); });
state.exec( [&sortedData, &randomData](nvbench::launch &launch) {
(sortedData = matx::sort(randomData, SORT_DIR_ASC)).run(cudaExecutor(launch.get_stream()));
});

}

Expand Down
4 changes: 2 additions & 2 deletions bench/00_transform/einsum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ void einsum_permute(nvbench::state &state, nvbench::type_list<ValueType>)

x.PrefetchDevice(0);

cutensor::einsum(y, "ijkl->likj", 0, x);
(y = cutensor::einsum("ijkl->likj", x)).run();

state.exec(
[&x, &y](nvbench::launch &launch) {
cutensor::einsum(y, "ijkl->likj", (cudaStream_t)launch.get_stream(), x);
(y = cutensor::einsum("ijkl->likj", x)).run(cudaExecutor(launch.get_stream()));
});
}

Expand Down
6 changes: 3 additions & 3 deletions bench/00_transform/fft.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void fft1d_no_batches_pow_2(nvbench::state &state,
xv.PrefetchDevice(0);

state.exec(
[&xv](nvbench::launch &launch) { fft(xv, xv, 0, launch.get_stream()); });
[&xv](nvbench::launch &launch) { (xv = fft(xv)).run(cudaExecutor(launch.get_stream())); });
}
NVBENCH_BENCH_TYPES(fft1d_no_batches_pow_2, NVBENCH_TYPE_AXES(fft_types))
.add_int64_power_of_two_axis("FFT size", nvbench::range(10, 18, 1));
Expand All @@ -35,7 +35,7 @@ void fft1d_no_batches_non_pow_2(nvbench::state &state,
xv.PrefetchDevice(0);

state.exec(
[&xv](nvbench::launch &launch) { fft(xv, xv, 0, launch.get_stream()); });
[&xv](nvbench::launch &launch) { (xv = fft(xv)).run(cudaExecutor(launch.get_stream())); });
}
NVBENCH_BENCH_TYPES(fft1d_no_batches_non_pow_2, NVBENCH_TYPE_AXES(fft_types))
.add_int64_axis("FFT size", nvbench::range(50000, 250000, 50000));
Expand All @@ -50,7 +50,7 @@ void fft1d_batches_pow_2(nvbench::state &state, nvbench::type_list<ValueType>)
xv.PrefetchDevice(0);

state.exec(
[&xv](nvbench::launch &launch) { fft(xv, xv, 0, launch.get_stream()); });
[&xv](nvbench::launch &launch) { (xv = fft(xv)).run(cudaExecutor(launch.get_stream())); });
}
NVBENCH_BENCH_TYPES(fft1d_batches_pow_2, NVBENCH_TYPE_AXES(fft_types))
.add_int64_power_of_two_axis("FFT size", nvbench::range(10, 18, 1));
2 changes: 1 addition & 1 deletion bench/00_transform/matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void pow2_matmul_bench(nvbench::state &state, nvbench::type_list<ValueType>)
// Report throughput stats:

state.exec([&av, &bv, &cv](nvbench::launch &launch) {
matmul(cv, av, bv, launch.get_stream());
(cv = matmul(av, bv)).run(cudaExecutor(launch.get_stream()));
});

auto seconds =
Expand Down
4 changes: 2 additions & 2 deletions bench/00_transform/qr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ void qr_batch(nvbench::state &state,

// warm up
nvtxRangePushA("Warmup");
qr(Q, R, A, stream);
(mtie(Q, R) = qr(A)).run(stream);

cudaDeviceSynchronize();
nvtxRangePop();

MATX_NVTX_START_RANGE( "Exec", matx_nvxtLogLevels::MATX_NVTX_LOG_ALL, 1 )
state.exec(
[&Q, &R, &A](nvbench::launch &launch) {
qr(Q, R, A, launch.get_stream()); });
(mtie(Q, R) = qr(A)).run(cudaExecutor{launch.get_stream()}); });
MATX_NVTX_END_RANGE( 1 )

}
Expand Down
8 changes: 4 additions & 4 deletions bench/00_transform/svd_power.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ void svdpi_batch(nvbench::state &state,

// warm up
nvtxRangePushA("Warmup");
svdpi(U, S, VT, A, x0, iterations, stream, r);
(mtie(U, S, VT) = svdpi(A, x0, iterations, r)).run(stream);
cudaDeviceSynchronize();
nvtxRangePop();

MATX_NVTX_START_RANGE( "Exec", matx_nvxtLogLevels::MATX_NVTX_LOG_ALL, 1 )
state.exec(
[&U, &S, &VT, &A, &x0, &iterations, &r](nvbench::launch &launch) {
svdpi(U, S, VT, A, x0, iterations, launch.get_stream(), r); });
(mtie(U, S, VT) = svdpi(A, x0, iterations, r)).run(cudaExecutor{launch.get_stream()}); });
MATX_NVTX_END_RANGE( 1 )

}
Expand Down Expand Up @@ -96,14 +96,14 @@ void svdbpi_batch(nvbench::state &state,

// warm up
nvtxRangePushA("Warmup");
svdbpi(U, S, VT, A, iterations, stream);
(mtie(U, S, VT) = svdbpi(A, iterations)).run(stream);
cudaDeviceSynchronize();
nvtxRangePop();

MATX_NVTX_START_RANGE( "Exec", matx_nvxtLogLevels::MATX_NVTX_LOG_ALL, 1 )
state.exec(
[&U, &S, &VT, &A, &iterations, &r](nvbench::launch &launch) {
svdbpi(U, S, VT, A, iterations, launch.get_stream()); });
(mtie(U, S, VT) = svdbpi(A, iterations)).run(cudaExecutor{launch.get_stream()}); });
MATX_NVTX_END_RANGE( 1 )
}

Expand Down
21 changes: 0 additions & 21 deletions docs/_sources/api/creation/operators/diag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,6 @@ Examples
:end-before: example-end diag-op-test-1
:dedent:

The generator form of ``diag()`` has both a shaped and a shape-less form. The shaped form takes a single argument specifying
the shape of the operator. This is useful when the operator is used in contexts where it must have a shape.
For example:

.. code-block:: cpp

auto krondiag = kron(diag({4, 4}, 5));

Without a shape, the ``kron`` operator would not be able to generate a Kronecker product, and would result
in a compiler error.

Shapeless is useful when the size is already known by another operator:

.. code-block:: cpp

auto t2 = make_tensor<float>({5, 5});
(t2 = diag(5)).run();

In the case above the lazy assignment of ``t2`` is done at runtime and will only request elements 0:5,0:5
since the number of elements fetched is dictated by the size of ``t2``.

Generator
_________

Expand Down
22 changes: 0 additions & 22 deletions docs/_sources/api/creation/operators/eye.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,8 @@ eye

Generate an identity tensor

``eye()`` has both a shaped and a shape-less form. The shaped form takes a single argument specifying
the shape of the operator. This is useful when the operator is used in contexts where it must have a shape.
For example:

.. code-block:: cpp

auto kroneye = kron(eye({4, 4}));

Without a shape, the ``kron`` operator would not be able to generate a Kronecker product, and would result
in a compiler error.

Shapeless is useful when the size is already known by another operator:

.. code-block:: cpp

auto t2 = make_tensor<float>({5, 5});
(t2 = eye()).run();

In the case above the lazy assignment of ``t2`` is done at runtime and will only request elements 0:5,0:5
since the number of elements fetched is dictated by the size of ``t2``.

.. doxygenfunction:: matx::eye(ShapeType &&s)
.. doxygenfunction:: matx::eye(const index_t (&s)[RANK])
.. doxygenfunction:: matx::eye()

Examples
~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion docs/_sources/api/creation/tensors/make.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Return by Value
.. doxygenfunction:: make_tensor( const index_t (&shape)[RANK], matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( TensorType &tensor, const index_t (&shape)[TensorType::Rank()], matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( ShapeType &&shape, matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( TensorType &tensor, typename TensorType::shape_container &&shape, matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( TensorType &tensor, ShapeType &&shape, matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( TensorType &tensor, matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
.. doxygenfunction:: make_tensor( T *data, const index_t (&shape)[RANK], bool owning = false)
Expand Down
4 changes: 2 additions & 2 deletions docs/_sources/api/dft/fft/fft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Perform a 1D FFT
These functions are currently not supported with host-based executors (CPU)


.. doxygenfunction:: fft(OutputTensor o, const InputTensor i, uint64_t fft_size = 0, cudaStream_t stream = 0)
.. doxygenfunction:: fft(OpA &&a, uint64_t fft_size = 0)
.. doxygenfunction:: fft(OpA &&a, const int32_t (&axis)[1], uint64_t fft_size = 0)

Examples
~~~~~~~~
Expand All @@ -19,7 +20,6 @@ Examples
:end-before: example-end fft-1
:dedent:

.. doxygenfunction:: fft(OutputTensor out, const InputTensor in, const int32_t (&axis)[1], uint64_t fft_size = 0, cudaStream_t stream = 0)

.. literalinclude:: ../../../../test/00_transform/FFT.cu
:language: cpp
Expand Down
5 changes: 2 additions & 3 deletions docs/_sources/api/dft/fft/fft2d.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Perform a 2D FFT
These functions are currently not supported with host-based executors (CPU)


.. doxygenfunction:: fft2(OutputTensor o, const InputTensor i, cudaStream_t stream = 0)
.. doxygenfunction:: fft2(OpA &&a)
.. doxygenfunction:: fft2(OpA &&a, const int32_t (&axis)[2])

Examples
~~~~~~~~
Expand All @@ -19,8 +20,6 @@ Examples
:end-before: example-end fft2-1
:dedent:

.. doxygenfunction:: fft2(OutputTensor out, const InputTensor in, const int (&axis)[2], cudaStream_t stream = 0)

.. literalinclude:: ../../../../test/00_transform/FFT.cu
:language: cpp
:start-after: example-begin fft2-2
Expand Down
4 changes: 2 additions & 2 deletions docs/_sources/api/dft/fft/ifft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Perform a 1D inverse FFT
These functions are currently not supported with host-based executors (CPU)


.. doxygenfunction:: ifft(OutputTensor o, const InputTensor i, uint64_t fft_size = 0, cudaStream_t stream = 0)
.. doxygenfunction:: ifft(OpA &&a, uint64_t fft_size = 0)
.. doxygenfunction:: ifft(OpA &&a, const int32_t (&axis)[1], uint64_t fft_size = 0)

Examples
~~~~~~~~
Expand All @@ -19,7 +20,6 @@ Examples
:end-before: example-end ifft-1
:dedent:

.. doxygenfunction:: ifft(OutputTensor out, const InputTensor in, const int (&axis)[1], uint64_t fft_size = 0, cudaStream_t stream = 0)

.. literalinclude:: ../../../../test/00_transform/FFT.cu
:language: cpp
Expand Down
5 changes: 2 additions & 3 deletions docs/_sources/api/dft/fft/ifft2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Perform a 2D inverse FFT
These functions are currently not supported with host-based executors (CPU)


.. doxygenfunction:: ifft2(OutputTensor o, const InputTensor i, cudaStream_t stream = 0)
.. doxygenfunction:: ifft2(OpA &&a)
.. doxygenfunction:: ifft2(OpA &&a, const int32_t (&axis)[2])

Examples
~~~~~~~~
Expand All @@ -19,8 +20,6 @@ Examples
:end-before: example-end ifft2-1
:dedent:

.. doxygenfunction:: ifft2(OutputTensor out, const InputTensor in, const int (&axis)[2], cudaStream_t stream = 0)

.. literalinclude:: ../../../../test/00_transform/FFT.cu
:language: cpp
:start-after: example-begin ifft2-2
Expand Down
2 changes: 1 addition & 1 deletion docs/_sources/api/linalg/decomp/inverse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ getri/getrf functions for LU decomposition.
.. note::
This function is currently is not supported with host-based executors (CPU)

.. doxygenfunction:: inv(TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream = 0)
.. doxygenfunction:: inv(const OpA &a)

Examples
~~~~~~~~
Expand Down
15 changes: 14 additions & 1 deletion docs/_sources/api/linalg/matvec/einsum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ of ``einsum`` operations are:
* Inner products
* Transposes
* Reductions
* Trace

While many of these operations are possible using other methods in MatX, ``einsum`` typically has a
shorter syntax, and is sometimes more optimized than a direct version of the operation.

.. note::
Using einsum() requires a minimum of cuTENSOR 1.7.0 and cuTensorNet 23.03.0.20. These are downloaded
automatically by the CMake build, but offline environments must provide these versions manually.

As of now, MatX only supports a limited set of ``einsum`` operations that would be supported in
the NumPy version. Specifically only tensor contractions, inner products, and GEMMs are supported
and tested at this time. MatX also does not support broadcast '...' notation and has no plans to. While
Expand Down Expand Up @@ -119,4 +124,12 @@ Sum
:language: cpp
:start-after: example-begin einsum-sum-1
:end-before: example-end einsum-sum-1
:dedent:
:dedent:

Trace
~~~~~
.. literalinclude:: ../../../../test/00_tensor/EinsumTests.cu
:language: cpp
:start-after: example-begin einsum-trace-1
:end-before: example-end einsum-trace-1
:dedent:
Loading