Skip to content

Commit 43be829

Browse files
committed
[QNN] Concatenate operator
1 parent 3ac27fc commit 43be829

2 files changed

Lines changed: 108 additions & 0 deletions

File tree

python/tvm/relay/qnn/op/qnn.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from __future__ import absolute_import as _abs
2121
from . import _make
22+
from tvm import relay
2223

2324
def requantize(data,
2425
input_scale,
@@ -72,3 +73,55 @@ def requantize(data,
7273
output_zero_point,
7374
rounding,
7475
out_dtype)
76+
77+
def concatenate(data,
                input_scales,
                input_zero_points,
                output_scale,
                output_zero_point,
                output_dtype,
                axis):
    """Concatenate the quantized input tensors along the given axis.

    Parameters
    ----------
    data : Union(List[relay.Expr], Tuple[relay.Expr])
        The list of quantized tensors to concatenate.

    input_scales : List[float32]
        The scales of the quantized input tensors, one per tensor in `data`.

    input_zero_points : List[int32]
        The zero points of the quantized input tensors, one per tensor
        in `data`.

    output_scale : float32
        The scale of the output tensor.

    output_zero_point : int32
        The zero point of the output tensor.

    output_dtype : str
        The data type of the output tensor; forwarded to requantize for
        every input whose qnn params differ from the output's.

    axis : int
        The axis along which the tensors are concatenated.

    Returns
    -------
    result: relay.Expr
        The concatenated quantized tensor.

    Raises
    ------
    ValueError
        If `data` is empty, or if the lengths of `input_scales` and
        `input_zero_points` do not match the number of input tensors.
    """
    data = list(data)
    if not data:
        raise ValueError("data must contain at least one quantized tensor")
    if not len(data) == len(input_scales) == len(input_zero_points):
        raise ValueError("input_scales and input_zero_points must provide "
                         "one entry per input tensor")

    # If the output qnn params do not match an input's qnn params, requantize
    # that input so every tensor shares the output scale/zero point.
    requantized_exprs = list(data)
    for idx, quantized_expr in enumerate(data):
        scale = input_scales[idx]
        zero_point = input_zero_points[idx]
        if scale != output_scale or zero_point != output_zero_point:
            requantized_exprs[idx] = requantize(quantized_expr,
                                                input_scale=scale,
                                                input_zero_point=zero_point,
                                                output_scale=output_scale,
                                                output_zero_point=output_zero_point,
                                                out_dtype=output_dtype)
    # All tensors now share the same qnn params, so the plain relay
    # concatenate applies directly.
    return relay.concatenate(tuple(requantized_exprs), axis)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tvm
19+
import numpy as np
20+
from tvm import relay
21+
from tvm.contrib import graph_runtime
22+
import topi.testing
23+
24+
def test_qnn_concatenate():
    """Check qnn.concatenate against a numpy golden reference."""
    data_dtype = 'int32'
    axis = 0

    # Two 1x64 integer inputs covering negative and positive values.
    x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
    y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)

    # Both inputs use the same scale, spread over the int32 value range.
    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

    x = relay.var("x", shape=(1, 64), dtype=data_dtype)
    y = relay.var("y", shape=(1, 64), dtype=data_dtype)
    z = relay.qnn.op.concatenate(
        (x, y),
        input_scales=[x_scale, y_scale],
        input_zero_points=[0, 0],
        output_scale=y_scale,
        output_zero_point=1,
        output_dtype=data_dtype,
        axis=axis)

    func = relay.Function([x, y], z)
    mod = relay.Module.from_expr(func)
    mod = relay.transform.Legalize()(mod)
    func = mod["main"]

    # The scales are identical and only the zero point shifts from 0 to 1,
    # so the expected result is the plain concatenation offset by one.
    golden_output = np.add(1, np.concatenate((x_data, y_data), axis=axis))

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    op_res = intrp.evaluate(func)(x_data, y_data)
    np.testing.assert_equal(op_res.asnumpy(), golden_output)
53+
54+
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_qnn_concatenate()

0 commit comments

Comments
 (0)