Skip to content

Commit fbc798a

Browse files
authored
Add implementation of dpnp.real_if_close (#2002)
* Add implementation of dpnp.real_if_close * Updated third party tests * Added CFD tests * Add more tests * State default value in the description * Add negative test for 'tol' keyword * Add proper link description per review comment
1 parent dea8ada commit fbc798a

File tree

7 files changed

+139
-21
lines changed

7 files changed

+139
-21
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
"prod",
115115
"proj",
116116
"real",
117+
"real_if_close",
117118
"remainder",
118119
"rint",
119120
"round",
@@ -505,6 +506,10 @@ def _process_ediff1d_args(arg, arg_name, ary_dtype, ary_sycl_queue, usm_type):
505506
:obj:`dpnp.arctan2` : Element-wise arc tangent of `x1/x2` choosing the quadrant correctly.
506507
:obj:`dpnp.arctan` : Trigonometric inverse tangent, element-wise.
507508
:obj:`dpnp.absolute` : Calculate the absolute value element-wise.
509+
:obj:`dpnp.real` : Return the real part of the complex argument.
510+
:obj:`dpnp.imag` : Return the imaginary part of the complex argument.
511+
:obj:`dpnp.real_if_close` : Return the real part of the input is complex
512+
with all imaginary parts close to zero.
508513
509514
Examples
510515
--------
@@ -2201,6 +2206,9 @@ def gradient(f, *varargs, axis=None, edge_order=1):
22012206
See Also
22022207
--------
22032208
:obj:`dpnp.real` : Return the real part of the complex argument.
2209+
:obj:`dpnp.angle` : Return the angle of the complex argument.
2210+
:obj:`dpnp.real_if_close` : Return the real part of the input is complex
2211+
with all imaginary parts close to zero.
22042212
:obj:`dpnp.conj` : Return the complex conjugate, element-wise.
22052213
:obj:`dpnp.conjugate` : Return the complex conjugate, element-wise.
22062214
@@ -3054,6 +3062,28 @@ def prod(
30543062
the same data type. If the input is a complex floating-point
30553063
data type, the returned array has a floating-point data type
30563064
with the same floating-point precision as complex input.
3065+
3066+
See Also
3067+
--------
3068+
:obj:`dpnp.real_if_close` : Return the real part of the input is complex
3069+
with all imaginary parts close to zero.
3070+
:obj:`dpnp.imag` : Return the imaginary part of the complex argument.
3071+
:obj:`dpnp.angle` : Return the angle of the complex argument.
3072+
3073+
Examples
3074+
--------
3075+
>>> import dpnp as np
3076+
>>> a = np.array([1+2j, 3+4j, 5+6j])
3077+
>>> a.real
3078+
array([1., 3., 5.])
3079+
>>> a.real = 9
3080+
>>> a
3081+
array([9.+2.j, 9.+4.j, 9.+6.j])
3082+
>>> a.real = np.array([9, 8, 7])
3083+
>>> a
3084+
array([9.+2.j, 8.+4.j, 7.+6.j])
3085+
>>> np.real(np.array(1 + 1j))
3086+
array(1.)
30573087
"""
30583088

30593089
real = DPNPReal(
@@ -3064,6 +3094,69 @@ def prod(
30643094
)
30653095

30663096

3097+
def real_if_close(a, tol=100):
3098+
"""
3099+
If input is complex with all imaginary parts close to zero, return real
3100+
parts.
3101+
3102+
"Close to zero" is defined as `tol` * (machine epsilon of the type for `a`).
3103+
3104+
For full documentation refer to :obj:`numpy.real_if_close`.
3105+
3106+
Parameters
3107+
----------
3108+
a : {dpnp.ndarray, usm_ndarray}
3109+
Input array.
3110+
tol : scalar, optional
3111+
Tolerance in machine epsilons for the complex part of the elements in
3112+
the array. If the tolerance is <=1, then the absolute tolerance is used.
3113+
Default: ``100``.
3114+
3115+
Returns
3116+
-------
3117+
out : dpnp.ndarray
3118+
If `a` is real, the type of `a` is used for the output. If `a` has
3119+
complex elements, the returned type is float.
3120+
3121+
See Also
3122+
--------
3123+
:obj:`dpnp.real` : Return the real part of the complex argument.
3124+
:obj:`dpnp.imag` : Return the imaginary part of the complex argument.
3125+
:obj:`dpnp.angle` : Return the angle of the complex argument.
3126+
3127+
Examples
3128+
--------
3129+
>>> import dpnp as np
3130+
>>> np.finfo(np.float64).eps
3131+
2.220446049250313e-16 # may vary
3132+
3133+
>>> a = np.array([2.1 + 4e-14j, 5.2 + 3e-15j])
3134+
>>> np.real_if_close(a, tol=1000)
3135+
array([2.1, 5.2])
3136+
3137+
>>> a = np.array([2.1 + 4e-13j, 5.2 + 3e-15j])
3138+
>>> np.real_if_close(a, tol=1000)
3139+
array([2.1+4.e-13j, 5.2+3.e-15j])
3140+
3141+
"""
3142+
3143+
dpnp.check_supported_arrays_type(a)
3144+
3145+
if not dpnp.issubdtype(a.dtype, dpnp.complexfloating):
3146+
return a
3147+
3148+
if not dpnp.isscalar(tol):
3149+
raise TypeError(f"Tolerance must be a scalar, but got {type(tol)}")
3150+
3151+
if tol > 1:
3152+
f = dpnp.finfo(a.dtype.type)
3153+
tol = f.eps * tol
3154+
3155+
if dpnp.all(dpnp.abs(a.imag) < tol):
3156+
return a.real
3157+
return a
3158+
3159+
30673160
_REMAINDER_DOCSTRING = """
30683161
Calculates the remainder of division for each element `x1_i` of the input array
30693162
`x1` with the respective element `x2_i` of the input array `x2`.

tests/skipped_tests.tbl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_par
182182
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_2_{shapes=[(3, 2), (3, 4)]}::test_invalid_broadcast
183183
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_3_{shapes=[(0,), (2,)]}::test_invalid_broadcast
184184

185-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_real_dtypes
186-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_tol_real_dtypes
187-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_true
188-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_false
189-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_integer_tol_true
190-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_integer_tol_false
191-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_float_tol_true
192-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_float_tol_false
193185
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp
194186
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_period
195187
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_left_right

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,14 +236,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_par
236236
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_2_{shapes=[(3, 2), (3, 4)]}::test_invalid_broadcast
237237
tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_3_{shapes=[(0,), (2,)]}::test_invalid_broadcast
238238

239-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_real_dtypes
240-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_tol_real_dtypes
241-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_true
242-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_false
243-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_integer_tol_true
244-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_integer_tol_false
245-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_float_tol_true
246-
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_float_tol_false
247239
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp
248240
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_period
249241
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_left_right

tests/test_mathematical.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,41 @@ def test_f16_corner_values_with_scalar(self, val, scalar):
13771377
assert_equal(result, expected)
13781378

13791379

1380+
class TestRealIfClose:
1381+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
1382+
def test_basic(self, dt):
1383+
a = numpy.random.rand(10).astype(dt)
1384+
ia = dpnp.array(a)
1385+
1386+
result = dpnp.real_if_close(ia + 1e-15j)
1387+
expected = numpy.real_if_close(a + 1e-15j)
1388+
assert_equal(result, expected)
1389+
1390+
@pytest.mark.parametrize("dt", get_float_dtypes())
1391+
def test_singlecomplex(self, dt):
1392+
a = numpy.random.rand(10).astype(dt)
1393+
ia = dpnp.array(a)
1394+
1395+
result = dpnp.real_if_close(ia + 1e-7j)
1396+
expected = numpy.real_if_close(a + 1e-7j)
1397+
assert_equal(result, expected)
1398+
1399+
@pytest.mark.parametrize("dt", get_float_dtypes())
1400+
def test_tol(self, dt):
1401+
a = numpy.random.rand(10).astype(dt)
1402+
ia = dpnp.array(a)
1403+
1404+
result = dpnp.real_if_close(ia + 1e-7j, tol=1e-6)
1405+
expected = numpy.real_if_close(a + 1e-7j, tol=1e-6)
1406+
assert_equal(result, expected)
1407+
1408+
@pytest.mark.parametrize("xp", [numpy, dpnp])
1409+
@pytest.mark.parametrize("tol_val", [[10], (1, 2)], ids=["list", "tuple"])
1410+
def test_wrong_tol_type(self, xp, tol_val):
1411+
a = xp.array([2.1 + 4e-14j, 5.2 + 3e-15j])
1412+
assert_raises(TypeError, xp.real_if_close, a, tol=tol_val)
1413+
1414+
13801415
class TestUnwrap:
13811416
@pytest.mark.parametrize("dt", get_float_dtypes())
13821417
def test_basic(self, dt):

tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ def test_meshgrid(device):
493493
pytest.param(
494494
"real", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)]
495495
),
496+
pytest.param("real_if_close", [2.1 + 4e-15j, 5.2 + 3e-16j]),
496497
pytest.param("reciprocal", [1.0, 2.0, 4.0, 7.0]),
497498
pytest.param("sign", [-5.0, 0.0, 4.5]),
498499
pytest.param("signbit", [-5.0, 0.0, 4.5]),

tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ def test_norm(usm_type, ord, axis):
618618
pytest.param(
619619
"real", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)]
620620
),
621+
pytest.param("real_if_close", [2.1 + 4e-15j, 5.2 + 3e-16j]),
621622
pytest.param("reciprocal", [1.0, 2.0, 4.0, 7.0]),
622623
pytest.param("reduce_hypot", [1.0, 2.0, 4.0, 7.0]),
623624
pytest.param("rsqrt", [1, 8, 27]),

tests/third_party/cupy/math_tests/test_misc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_nan_to_num_inf(self):
245245
def test_nan_to_num_nan(self):
246246
self.check_unary_nan("nan_to_num")
247247

248-
@pytest.mark.skip(reason="Scalar input is not supported")
248+
@pytest.mark.skip(reason="scalar input is not supported")
249249
@testing.numpy_cupy_allclose(atol=1e-5)
250250
def test_nan_to_num_scalar_nan(self, xp):
251251
return xp.nan_to_num(xp.nan)
@@ -301,7 +301,8 @@ def test_real_if_close_with_tol_real_dtypes(self, xp, dtype):
301301
def test_real_if_close_true(self, xp, dtype):
302302
dtype = numpy.dtype(dtype).char.lower()
303303
tol = numpy.finfo(dtype).eps * 90
304-
x = testing.shaped_random((10,), xp, dtype) + tol * 1j
304+
x = testing.shaped_random((10,), xp, dtype)
305+
x = xp.add(x, tol * 1j, dtype=xp.result_type(x, 1j))
305306
out = xp.real_if_close(x)
306307
assert x.dtype != out.dtype
307308
return out
@@ -311,7 +312,8 @@ def test_real_if_close_true(self, xp, dtype):
311312
def test_real_if_close_false(self, xp, dtype):
312313
dtype = numpy.dtype(dtype).char.lower()
313314
tol = numpy.finfo(dtype).eps * 110
314-
x = testing.shaped_random((10,), xp, dtype) + tol * 1j
315+
x = testing.shaped_random((10,), xp, dtype)
316+
x = xp.add(x, tol * 1j, dtype=xp.result_type(x, 1j))
315317
out = xp.real_if_close(x)
316318
assert x.dtype == out.dtype
317319
return out
@@ -321,7 +323,8 @@ def test_real_if_close_false(self, xp, dtype):
321323
def test_real_if_close_with_integer_tol_true(self, xp, dtype):
322324
dtype = numpy.dtype(dtype).char.lower()
323325
tol = numpy.finfo(dtype).eps * 140
324-
x = testing.shaped_random((10,), xp, dtype) + tol * 1j
326+
x = testing.shaped_random((10,), xp, dtype)
327+
x = xp.add(x, tol * 1j, dtype=xp.result_type(x, 1j))
325328
out = xp.real_if_close(x, tol=150)
326329
assert x.dtype != out.dtype
327330
return out
@@ -331,7 +334,8 @@ def test_real_if_close_with_integer_tol_true(self, xp, dtype):
331334
def test_real_if_close_with_integer_tol_false(self, xp, dtype):
332335
dtype = numpy.dtype(dtype).char.lower()
333336
tol = numpy.finfo(dtype).eps * 50
334-
x = testing.shaped_random((10,), xp, dtype) + tol * 1j
337+
x = testing.shaped_random((10,), xp, dtype)
338+
x = xp.add(x, tol * 1j, dtype=xp.result_type(x, 1j))
335339
out = xp.real_if_close(x, tol=30)
336340
assert x.dtype == out.dtype
337341
return out

0 commit comments

Comments
 (0)