Skip to content

Commit f06ef4f

Browse files
anijain2305 (committed by vinx13)
authored and committed
[QNN] Concatenate operator (#3730)
1 parent 5498e54 commit f06ef4f

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

python/tvm/relay/qnn/op/qnn.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""QNN dialect operators."""
1919

2020
from __future__ import absolute_import as _abs
21+
from tvm import relay
2122
from . import _make
2223

2324
def requantize(data,
@@ -72,3 +73,75 @@ def requantize(data,
7273
output_zero_point,
7374
rounding,
7475
out_dtype)
76+
77+
def concatenate(data,
                input_scales,
                input_zero_points,
                output_scale,
                output_zero_point,
                axis):
    """Concatenate the quantized input tensors along the given axis.

    Parameters
    ----------
    data : Union(List[relay.Expr], Tuple[relay.Expr])
        The list of quantized tensors.

    input_scales : List[float32]
        The list of scales of input quantized tensors.

    input_zero_points : List[int32]
        The list of zero points of input quantized tensors.

    output_scale : float32
        The scale of the output quantized tensor.

    output_zero_point : int32
        The zero point of the output quantized tensor.

    axis : int
        The axis along which the tensors are concatenated.

    Returns
    -------
    result: relay.Expr
        The concatenated quantized tensor.

    Raises
    ------
    ValueError
        If no input tensors are given.
    """

    data = list(data)
    if not data:
        # Fail early with a clear message instead of an IndexError below.
        raise ValueError("qnn.concatenate requires at least one input tensor")
    requantized_exprs = list(data)

    # Find the dtype of the input expr. This is required for the requantize op. Since, this is
    # concatenate op, the dtype of the input is same as dtype of the output.
    data0 = relay.transform.infer_type(data[0])
    in_dtype = data0.checked_type.dtype

    # First check if all the input qnn params match. If yes, we can call concatenate first, followed
    # by a requantize.
    if all(scale == input_scales[0] for scale in input_scales)\
            and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
        out = relay.concatenate(tuple(data), axis)
        input_scale = input_scales[0]
        input_zero_point = input_zero_points[0]
        # Only requantize when the shared input params differ from the output params.
        if input_scale != output_scale or input_zero_point != output_zero_point:
            out = requantize(data=out,
                             input_scale=input_scale,
                             input_zero_point=input_zero_point,
                             output_scale=output_scale,
                             output_zero_point=output_zero_point,
                             out_dtype=in_dtype)
        return out

    # If the output qnn params do not match the input qnn params, we can call requantize on the
    # input expr first, followed by a concatenate on the requantized input exprs.
    for idx, quantized_expr in enumerate(data):
        input_scale = input_scales[idx]
        input_zero_point = input_zero_points[idx]
        # Inputs whose params already match the output are passed through unchanged.
        if input_scale != output_scale or input_zero_point != output_zero_point:
            requantized_exprs[idx] = requantize(data=quantized_expr,
                                                input_scale=input_scale,
                                                input_zero_point=input_zero_point,
                                                output_scale=output_scale,
                                                output_zero_point=output_zero_point,
                                                out_dtype=in_dtype)
    return relay.concatenate(tuple(requantized_exprs), axis)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tvm
19+
import numpy as np
20+
from tvm import relay
21+
from tvm.contrib import graph_runtime
22+
import topi.testing
23+
24+
def test_same_io_qnn_params():
    """Identical input/output qnn params: no requantize should be emitted."""
    dtype = 'int32'
    concat_axis = 0
    x_vals = np.arange(-32, 32, 1).reshape(1, 64).astype(dtype)
    y_vals = np.arange(-64, 64, 2).reshape(1, 64).astype(dtype)
    x_scale = y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

    x = relay.var("x", shape=(1, 64), dtype=dtype)
    y = relay.var("y", shape=(1, 64), dtype=dtype)
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[x_scale, y_scale],
                                 input_zero_points=[0, 0],
                                 output_scale=y_scale,
                                 output_zero_point=0,
                                 axis=concat_axis)

    func = relay.Function([x, y], z)
    # All params match, so lowering must not introduce any requantize ops.
    assert func.astext().count('requantize') == 0
    mod = relay.transform.Legalize()(relay.Module.from_expr(func))
    func = mod["main"]

    expected = np.concatenate((x_vals, y_vals), axis=concat_axis)

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    result = intrp.evaluate(func)(x_vals, y_vals)
    np.testing.assert_equal(result.asnumpy(), expected)
52+
53+
def test_different_io_qnn_params():
    """Both inputs differ from the output params: both must be requantized."""
    dtype = 'int32'
    concat_axis = 0
    x_vals = np.arange(-32, 32, 1).reshape(1, 64).astype(dtype)
    y_vals = np.arange(-64, 64, 2).reshape(1, 64).astype(dtype)
    x_scale = y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

    x = relay.var("x", shape=(1, 64), dtype=dtype)
    y = relay.var("y", shape=(1, 64), dtype=dtype)
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[x_scale, y_scale],
                                 input_zero_points=[3, 4],
                                 output_scale=y_scale,
                                 output_zero_point=1,
                                 axis=concat_axis)

    func = relay.Function([x, y], z)
    # Zero points 3 and 4 both differ from output zero point 1 -> two requantizes.
    assert func.astext().count('requantize') == 2
    mod = relay.transform.Legalize()(relay.Module.from_expr(func))
    func = mod["main"]

    # Shifting zero point from 3->1 subtracts 2; from 4->1 subtracts 3.
    expected = np.concatenate((x_vals - 2, y_vals - 3), axis=concat_axis)

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    result = intrp.evaluate(func)(x_vals, y_vals)
    np.testing.assert_equal(result.asnumpy(), expected)
81+
82+
def test_few_same_io_qnn_params():
    """Only one input matches the output params: exactly one requantize."""
    dtype = 'int32'
    concat_axis = 0
    x_vals = np.arange(-32, 32, 1).reshape(1, 64).astype(dtype)
    y_vals = np.arange(-64, 64, 2).reshape(1, 64).astype(dtype)
    x_scale = y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

    x = relay.var("x", shape=(1, 64), dtype=dtype)
    y = relay.var("y", shape=(1, 64), dtype=dtype)
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[x_scale, y_scale],
                                 input_zero_points=[0, 1],
                                 output_scale=y_scale,
                                 output_zero_point=1,
                                 axis=concat_axis)

    func = relay.Function([x, y], z)
    # Only x (zero point 0) differs from the output zero point 1.
    assert func.astext().count('requantize') == 1
    mod = relay.transform.Legalize()(relay.Module.from_expr(func))
    func = mod["main"]

    # x is shifted from zero point 0 to 1; y passes through untouched.
    expected = np.concatenate((x_vals + 1, y_vals), axis=concat_axis)

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    result = intrp.evaluate(func)(x_vals, y_vals)
    np.testing.assert_equal(result.asnumpy(), expected)
110+
111+
def test_same_i_qnn_params():
    """Inputs share params but differ from the output: concat then one requantize."""
    dtype = 'int32'
    concat_axis = 0
    x_vals = np.arange(-32, 32, 1).reshape(1, 64).astype(dtype)
    y_vals = np.arange(-64, 64, 2).reshape(1, 64).astype(dtype)
    x_scale = y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

    x = relay.var("x", shape=(1, 64), dtype=dtype)
    y = relay.var("y", shape=(1, 64), dtype=dtype)
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[x_scale, y_scale],
                                 input_zero_points=[0, 0],
                                 output_scale=y_scale,
                                 output_zero_point=1,
                                 axis=concat_axis)

    func = relay.Function([x, y], z)
    # Matching input params collapse to a single requantize after the concat.
    assert func.astext().count('requantize') == 1
    mod = relay.transform.Legalize()(relay.Module.from_expr(func))
    func = mod["main"]

    # Both halves are shifted from zero point 0 to the output zero point 1.
    expected = np.concatenate((x_vals + 1, y_vals + 1), axis=concat_axis)

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    result = intrp.evaluate(func)(x_vals, y_vals)
    np.testing.assert_equal(result.asnumpy(), expected)
139+
140+
141+
if __name__ == '__main__':
    # Run all qnn.concatenate test cases when invoked directly as a script.
    test_same_io_qnn_params()
    test_different_io_qnn_params()
    test_few_same_io_qnn_params()
    test_same_i_qnn_params()

0 commit comments

Comments
 (0)