Skip to content

Commit 6a28061

Browse files
author
Siyuan Feng
committed
[Relax] Enhance Relax op and ONNX frontend
This PR adds: - relax.op.one_hot - relax.op.mod and relax.op.floor_mod - relax.op.eye and relax.op.eye_like - ONNX frontend for one_hot - ONNX frontend for mod - ONNX frontend for one_hot and hardmax - ONNX frontend for eye_like
1 parent 7d2fa11 commit 6a28061

File tree

21 files changed

+659
-17
lines changed

21 files changed

+659
-17
lines changed

include/tvm/relax/attrs/manipulate.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,17 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
176176
}
177177
}; // struct ScatterNDAttrs
178178

179+
/*! \brief Attributes used in one_hot operator */
180+
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
181+
int depth;
182+
int axis;
183+
184+
TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
185+
TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension.");
186+
TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
187+
}
188+
}; // struct OneHotAttrs
189+
179190
} // namespace relax
180191
} // namespace tvm
181192

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ class Sub(BinaryBase):
287287
relax_op = relax.op.subtract
288288

289289
@classmethod
290-
def _impl_v1(cls, bb, inputs, attr, params):
290+
def _impl_v7(cls, bb, inputs, attr, params):
291291
return cls.base_impl(bb, inputs, attr, params)
292292

293293

@@ -298,7 +298,7 @@ class Mul(BinaryBase):
298298
relax_op = relax.op.multiply
299299

300300
@classmethod
301-
def _impl_v1(cls, bb, inputs, attr, params):
301+
def _impl_v7(cls, bb, inputs, attr, params):
302302
return cls.base_impl(bb, inputs, attr, params)
303303

304304

@@ -309,7 +309,7 @@ class Div(BinaryBase):
309309
relax_op = relax.op.divide
310310

311311
@classmethod
312-
def _impl_v1(cls, bb, inputs, attr, params):
312+
def _impl_v7(cls, bb, inputs, attr, params):
313313
return cls.base_impl(bb, inputs, attr, params)
314314

315315

@@ -320,7 +320,24 @@ class Pow(BinaryBase):
320320
relax_op = relax.op.power
321321

322322
@classmethod
323-
def _impl_v1(cls, bb, inputs, attr, params):
323+
def _impl_v7(cls, bb, inputs, attr, params):
324+
return cls.base_impl(bb, inputs, attr, params)
325+
326+
327+
class Mod(BinaryBase):
328+
"""Converts an onnx Mod node into an equivalent Relax expression."""
329+
330+
numpy_op = _np.mod
331+
relax_op = relax.op.mod
332+
333+
@classmethod
334+
def _impl_v10(cls, bb, inputs, attr, params):
335+
if attr.get("fmod", 0) == 0:
336+
cls.numpy_op = _np.fmod
337+
cls.relax_op = relax.op.floor_mod
338+
else:
339+
cls.numpy_op = _np.mod
340+
cls.relax_op = relax.op.mod
324341
return cls.base_impl(bb, inputs, attr, params)
325342

326343

@@ -523,6 +540,23 @@ def _impl_v13(cls, bb, inputs, attr, params):
523540
return relax.op.nn.log_softmax(inputs[0], axis=axis)
524541

525542

543+
class Hardmax(OnnxOpConverter):
544+
"""Converts an onnx Hardmax node into an equivalent Relax expression."""
545+
546+
@classmethod
547+
def _impl_v13(cls, bb, inputs, attr, params):
548+
axis = attr.get("axis", -1)
549+
indices = inputs[0]
550+
dtype = indices.struct_info.dtype
551+
axis_len = int(inputs[0].struct_info.shape[axis])
552+
argmax = relax.op.argmax(indices, axis=axis)
553+
on_value = relax.PrimValue(tvm.tir.const(1.0, dtype))
554+
off_value = relax.PrimValue(tvm.tir.const(0.0, dtype))
555+
556+
one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
557+
return one_hot
558+
559+
526560
class Transpose(OnnxOpConverter):
527561
"""Converts an onnx Transpose node into an equivalent Relax expression."""
528562

@@ -731,6 +765,22 @@ def _impl_v1(cls, bb, inputs, attr, params):
731765
return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
732766

733767

768+
class EyeLike(OnnxOpConverter):
769+
"""Convert an onnx EyeLike node into an equivalent Relax expression."""
770+
771+
@classmethod
772+
def _impl_v9(cls, bb, inputs, attr, params):
773+
k = attr.get("k", 0)
774+
input_dtype = inputs[0].struct_info.dtype
775+
if "dtype" in attr and get_type(attr["dtype"]) != input_dtype:
776+
raise ValueError(
777+
f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})"
778+
)
779+
else:
780+
dtype = input_dtype
781+
return relax.op.eye_like(inputs[0], k, dtype)
782+
783+
734784
class Gemm(OnnxOpConverter):
735785
"""Convert an onnx Gemm node into an equivalent Relax expression."""
736786

@@ -2520,13 +2570,13 @@ def _impl_v11(cls, bb, inputs, attr, params):
25202570
depth = get_constant(inputs[1], params)
25212571
values = get_constant(inputs[2], params)
25222572
axis = attr.get("axis", -1)
2523-
dtype = values.struct_info.dtype
25242573
assert isinstance(depth, relax.Constant), "Only constant depth currently supported."
25252574
depth = depth.data.numpy().tolist()
25262575
assert isinstance(values, relax.Constant), "Only constant values currently supported."
25272576
values = values.data.numpy().tolist()
25282577
off_value, on_value = values
2529-
return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype)
2578+
off_value, on_value = relax.PrimValue(off_value), relax.PrimValue(on_value)
2579+
return relax.op.one_hot(indices, on_value, off_value, depth, axis)
25302580

25312581

25322582
class Unique(OnnxOpConverter):
@@ -2800,7 +2850,7 @@ def _get_convert_map():
28002850
"Sub": Sub,
28012851
"Mul": Mul,
28022852
"Div": Div,
2803-
# "Mod": Mod,
2853+
"Mod": Mod,
28042854
"Less": Less,
28052855
"LessOrEqual": LessOrEqual,
28062856
"Greater": Greater,
@@ -2870,7 +2920,7 @@ def _get_convert_map():
28702920
"Sigmoid": Sigmoid,
28712921
"Softmax": Softmax,
28722922
"LogSoftmax": LogSoftmax,
2873-
# "Hardmax": Hardmax,
2923+
"Hardmax": Hardmax,
28742924
"Transpose": Transpose,
28752925
"Unsqueeze": Unsqueeze,
28762926
"Where": Where,
@@ -2889,7 +2939,7 @@ def _get_convert_map():
28892939
"ScatterND": ScatterND,
28902940
# "Compress": Compress,
28912941
"Size": Size,
2892-
# "EyeLike": EyeLike,
2942+
"EyeLike": EyeLike,
28932943
# Normalization
28942944
"BatchNormalization": BatchNormalization,
28952945
"LayerNormalization": LayerNormalization,

python/tvm/relax/op/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
divide,
5151
equal,
5252
floor_divide,
53+
floor_mod,
5354
greater,
5455
greater_equal,
5556
left_shift,
@@ -60,6 +61,7 @@
6061
logical_xor,
6162
maximum,
6263
minimum,
64+
mod,
6365
multiply,
6466
not_equal,
6567
power,
@@ -72,6 +74,8 @@
7274
full_like,
7375
ones,
7476
ones_like,
77+
eye,
78+
eye_like,
7579
tril,
7680
triu,
7781
zeros,
@@ -89,6 +93,7 @@
8993
flatten,
9094
flip,
9195
layout_transform,
96+
one_hot,
9297
permute_dims,
9398
repeat,
9499
reshape,

python/tvm/relax/op/binary.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr:
139139
return _ffi_api.subtract(x1, x2) # type: ignore
140140

141141

142+
def mod(x1: Expr, x2: Expr) -> Expr:
143+
"""Modulo with numpy-style broadcasting.
144+
145+
Parameters
146+
----------
147+
x1 : Expr
148+
The first input tensor.
149+
x2 : Expr
150+
The second input tensor.
151+
"""
152+
return _ffi_api.mod(x1, x2) # type: ignore
153+
154+
155+
def floor_mod(x1: Expr, x2: Expr) -> Expr:
156+
"""Floor modulo with numpy-style broadcasting.
157+
158+
Parameters
159+
----------
160+
x1 : Expr
161+
The first input tensor.
162+
x2 : Expr
163+
The second input tensor.
164+
"""
165+
return _ffi_api.floor_mod(x1, x2) # type: ignore
166+
167+
142168
###################### Comparison operators ######################
143169

144170

python/tvm/relax/op/create.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
163163
return _ffi_api.zeros_like(x, dtype) # type: ignore
164164

165165

166+
def eye(
167+
n: Union[PrimExprLike, PrimValue],
168+
m: Optional[Union[PrimExprLike, PrimValue]] = None,
169+
k: Union[PrimExprLike, PrimValue] = 0,
170+
dtype: Union[str, DataType] = "float32",
171+
) -> Expr:
172+
"""Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.
173+
174+
Parameters
175+
----------
176+
n : Union[PrimExprLike, PrimValue]
177+
Number of rows in the output.
178+
179+
m : Optional[Union[PrimExprLike, PrimValue]]
180+
Number of columns in the output. If None, defaults to n.
181+
182+
k : Union[PrimExprLike, PrimValue]
183+
Index of the diagonal: 0 (the default) refers to the main diagonal,
184+
a positive value refers to an upper diagonal, and a negative value
185+
to a lower diagonal.
186+
187+
dtype : Union[str, DataType]
188+
The data type of the created tensor.
189+
190+
Returns
191+
-------
192+
result : relax.Expr
193+
The result tensor.
194+
"""
195+
m = n if m is None else m
196+
n = n if isinstance(n, PrimValue) else PrimValue(n)
197+
m = m if isinstance(m, PrimValue) else PrimValue(m)
198+
k = k if isinstance(k, PrimValue) else PrimValue(k)
199+
return _ffi_api.eye(n, m, k, dtype) # type: ignore
200+
201+
202+
def eye_like(
203+
x: Expr,
204+
k: Union[PrimExprLike, PrimValue] = 0,
205+
dtype: Optional[Union[str, DataType]] = None,
206+
) -> Expr:
207+
"""Return a 2-D tensor with ones on the diagonal and zeros elsewhere,
208+
with the same shape as the input tensor.
209+
210+
Parameters
211+
----------
212+
x : relax.Expr
213+
The input tensor, which provides the shape, and dtype
214+
when the `dtype` field is not specified.
215+
216+
k : Union[PrimExprLike, PrimValue]
217+
Index of the diagonal: 0 (the default) refers to the main diagonal,
218+
a positive value refers to an upper diagonal, and a negative value
219+
to a lower diagonal.
220+
221+
dtype : Optional[Union[str, DataType]]
222+
The data type of the created tensor.
223+
If dtype is not given, it will by default use the dtype of the input tensor.
224+
225+
Returns
226+
-------
227+
result : relax.Expr
228+
The result tensor.
229+
"""
230+
k = k if isinstance(k, PrimValue) else PrimValue(k)
231+
return _ffi_api.eye_like(x, k, dtype) # type: ignore
232+
233+
166234
def arange(
167235
start: Union[PrimExprLike, PrimValue],
168236
end: Optional[Union[PrimExprLike, PrimValue]] = None,

python/tvm/relax/op/manipulate.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat
550550
551551
"""
552552
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore
553+
554+
555+
def one_hot(
556+
indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1
557+
) -> Expr:
558+
"""Returns a one-hot tensor.
559+
560+
Parameters
561+
----------
562+
indices : relax.Expr
563+
The indices to set to `on_value`.
564+
565+
on_value : relax.PrimValue
566+
The value to fill at `indices`.
567+
568+
off_value : relax.PrimValue
569+
The value to fill at other locations.
570+
571+
depth : int
572+
The depth of the one-hot dimension.
573+
574+
axis : int, optional
575+
The axis to fill. Default is -1 which adds a new dimension at the end.
576+
577+
Returns
578+
-------
579+
result : relax.Expr
580+
The computed result.
581+
582+
Examples
583+
--------
584+
.. code-block:: python
585+
586+
indices = [0, 1, 2]
587+
depth = 3
588+
on_value = 1
589+
off_value = 0
590+
591+
one_hot(indices, on_value, off_value, depth) =
592+
[[1, 0, 0],
593+
[0, 1, 0],
594+
[0, 0, 1]]
595+
"""
596+
return _ffi_api.one_hot(indices, on_value, off_value, depth, axis) # type: ignore

python/tvm/relax/transform/legalize_ops/binary.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
4848
register_legalize("relax.power", _binary(topi.power))
4949
register_legalize("relax.subtract", _binary(topi.subtract))
5050
register_legalize("relax.equal", _binary(topi.equal))
51-
51+
register_legalize("relax.mod", _binary(topi.mod))
52+
register_legalize("relax.floor_mod", _binary(topi.floor_mod))
5253
register_legalize("relax.greater", _binary(topi.greater))
5354
register_legalize("relax.greater_equal", _binary(topi.greater_equal))
5455
register_legalize("relax.less", _binary(topi.less))

python/tvm/relax/transform/legalize_ops/create.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,36 @@ def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr:
7070
register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu"))
7171

7272

73+
def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc:
74+
def eye_call_te(bb: BlockBuilder, call: Call) -> Expr:
75+
_convert_to_scalar_const = lambda x: _try_convert_to_scalar_const(x, python_native=True)
76+
if is_like:
77+
x = call.args[0]
78+
k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else 0
79+
n, m = x.struct_info.shape
80+
dtype = x.struct_info.dtype
81+
else:
82+
n = _convert_to_scalar_const(call.args[0])
83+
m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else n
84+
k = _convert_to_scalar_const(call.args[2]) if len(call.args) > 2 else 0
85+
dtype = call.attrs.dtype
86+
87+
return bb.call_te(
88+
topi.eye,
89+
n,
90+
m,
91+
k,
92+
dtype,
93+
primfunc_name_hint=primfunc_name,
94+
)
95+
96+
return eye_call_te
97+
98+
99+
register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye"))
100+
register_legalize("relax.eye_like", _eye(is_like=True, primfunc_name="eye_like"))
101+
102+
73103
@register_legalize("relax.arange")
74104
def _arange(bb: BlockBuilder, call: Call) -> Expr:
75105
assert len(call.args) == 3

0 commit comments

Comments
 (0)