Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4215,6 +4215,66 @@ default NDArray argSort(int axis) {
*/
NDArray cumSum(int axis);

/**
* Return specified diagonals.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.arange(9.0f).reshape(3, 3);
* jshell&gt; array.diagonal();
* ND: (3) cpu() float32
* [0., 4., 8.]
* </pre>
*
* @return specified diagonals
*/
NDArray diagonal();

/**
* Return specified diagonals.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.arange(9.0f).reshape(3, 3);
* jshell&gt; array.diagonal(1);
* ND: (2) cpu() float32
* [1., 5.]
* jshell&gt; array.diagonal(-1);
* ND: (2) cpu() float32
* [3., 7.]
* </pre>
*
* @param offset Offset of the diagonal from the main diagonal. Can be positive or negative.
* @return specified diagonals
*/
NDArray diagonal(int offset);

/**
* Return specified diagonals.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.arange(27f).reshape(3, 3, 3);
* jshell&gt; array.diagonal(0, 1, 2);
* ND: (3, 3) cpu() float32
* [[ 0., 4., 8.],
* [ 9., 13., 17.],
* [18., 22., 26.],
* ]
* </pre>
*
* @param offset Offset of the diagonal from the main diagonal. Can be positive or negative.
* @param axis1 Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals
* should be taken.
* @param axis2 Axis to be used as the second axis of the 2-D sub-arrays from which the
* diagonals should be taken.
* @return specified diagonals
*/
NDArray diagonal(int offset, int axis1, int axis2);

/**
* Replace the handle of the NDArray with the other. The NDArray used for replacement will be
* killed.
Expand Down
18 changes: 18 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,24 @@ public NDArray cumSum(int axis) {
return getAlternativeArray().cumSum(axis);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal() {
return getAlternativeArray().diagonal(0);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset) {
return getAlternativeArray().diagonal(offset);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset, int axis1, int axis2) {
return getAlternativeArray().diagonal(offset, axis1, axis2);
}

/** {@inheritDoc} */
@Override
public NDArray isInfinite() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,24 @@ public NDArray cumSum(int axis) {
return manager.invoke("_np_cumsum", this, params);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal() {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset, int axis1, int axis2) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,24 @@ public PtNDArray cumSum(int axis) {
return JniUtils.cumSum(this, axis);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal() {
return JniUtils.diagonal(this, 0, 0, 1);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset) {
return JniUtils.diagonal(this, offset, 0, 1);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset, int axis1, int axis2) {
return JniUtils.diagonal(this, offset, axis1, axis2);
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,12 @@ public static PtNDArray cumSum(PtNDArray ndArray, long dim) {
ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim));
}

public static PtNDArray diagonal(PtNDArray ndArray, long offset, long axis1, long axis2) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchDiagonal(ndArray.getHandle(), offset, axis1, axis2));
}

public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) {
return new PtNDArray(
ndArray.getManager(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ native void torchIndexPut(

native long torchCumSum(long handle, long dim);

native long torchDiagonal(long handle, long offset, long axis1, long axis2);

native long torchFlatten(long handle, long startDim, long endDim);

native long torchFft(long handle, long length, long axis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,12 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchViewAsComple
API_END_RETURN()
#endif
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDiagonal(
JNIEnv* env, jobject jthis, jlong jhandle, jlong offset, jlong axis1, jlong axis2) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->diagonal(offset, axis1, axis2));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,24 @@ public NDArray cumSum() {
return cumSum(0);
}

/** {@inheritDoc} */
@Override
public NDArray diagonal() {
return manager.opExecutor("DiagPart").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset, int axis1, int axis2) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,24 @@ public RsNDArray cumSum(int axis) {
return toArray(RustLibrary.cumSum(getHandle(), axis));
}

/** {@inheritDoc} */
@Override
public NDArray diagonal() {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray diagonal(int offset, int axis1, int axis2) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,32 @@ public void testCumsum() {
}
}

@Test
public void testDiagonal() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
NDArray array = manager.arange(4.0f).reshape(2, 2);
NDArray expected = manager.create(new float[] {0f, 3f});
Assert.assertEquals(array.diagonal(), expected);

array = manager.arange(9.0f).reshape(3, 3);
expected = manager.create(new float[] {0f, 4f, 8f});
Assert.assertEquals(array.diagonal(), expected);

array = manager.arange(9.0f).reshape(3, 3);
expected = manager.create(new float[] {1f, 5f});
Assert.assertEquals(array.diagonal(1), expected);

array = manager.arange(9.0f).reshape(3, 3);
expected = manager.create(new float[] {3f, 7f});
Assert.assertEquals(array.diagonal(-1), expected);

array = manager.arange(27f).reshape(3, 3, 3);
expected =
manager.create(new float[][] {{0f, 4f, 8f}, {9f, 13f, 17f}, {18f, 22f, 26f}});
Assert.assertEquals(array.diagonal(0, 1, 2), expected);
}
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void testTile() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down
Loading