Skip to content

Commit 6f5f53e

Browse files
Ubuntuanijain2305
authored andcommitted
[QNN][Relay] Calling Dialect passes from inside Relay Build API.
1 parent 4ba911a commit 6f5f53e

File tree

8 files changed

+118
-10
lines changed

8 files changed

+118
-10
lines changed

include/tvm/relay/qnn/transform.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relay/qnn/transform.h
22+
*
23+
* This file implements a pass manager for QNN ops using Relay Pass manager.
24+
*/
25+
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
26+
#define TVM_RELAY_QNN_TRANSFORM_H_
27+
28+
#include <tvm/relay/transform.h>
29+
30+
namespace tvm {
31+
namespace relay {
32+
33+
using relay::transform::Pass;
34+
35+
namespace qnn {
36+
namespace transform {
37+
38+
/*!
39+
* \brief Legalizes a QNN expr.
40+
* \param Contains specifically two types of Legalizations. First, converts/Lowers an expression
41+
* containing QNN ops to an expression containing only core Relay ops. Each QNN op is lowered to a
42+
* sequence of exisiting Relay ops. This is a target-independent pass. One can register the
43+
* lowering/transformation function for this op using FTVMQnnCanonicalize attr_name for FTVMLegalize
44+
* op attribute. Second, as opposed to Relay Legalize, this one legalizes only QNN ops. One can
45+
* register a transformation/legalization function for an op by using the FTVMQnnLegalize attr_name
46+
* for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize gives us separation of
47+
* concerns, leading to a better software practice. The legalization can be configured to happen per
48+
* target.
49+
*
50+
* \return The pass.
51+
*/
52+
Pass Legalize();
53+
54+
} // namespace transform
55+
56+
} // namespace qnn
57+
} // namespace relay
58+
} // namespace tvm
59+
60+
#endif // TVM_RELAY_QNN_TRANSFORM_H_

src/relay/backend/build_module.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/runtime/vm.h>
2828
#include <tvm/relay/expr.h>
2929
#include <tvm/relay/transform.h>
30+
#include <tvm/relay/qnn/transform.h>
3031
#include <memory>
3132

3233
#include "utils.h"
@@ -282,6 +283,15 @@ class RelayBuildModule : public runtime::ModuleNode {
282283
const TargetsMap& targets,
283284
const std::unordered_map<std::string, runtime::NDArray>& params) {
284285
Array<Pass> pass_seqs;
286+
287+
// Run all dialect legalization passes.
288+
pass_seqs.push_back(relay::qnn::transform::Legalize());
289+
290+
// Legalize pass is restricted to homogeneous execution for now.
291+
if (targets.size() == 1) {
292+
pass_seqs.push_back(transform::Legalize());
293+
}
294+
285295
pass_seqs.push_back(transform::SimplifyInference());
286296
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
287297
Expr expr = args[0];
@@ -304,11 +314,6 @@ class RelayBuildModule : public runtime::ModuleNode {
304314
pass_seqs.push_back(transform::CanonicalizeCast());
305315
pass_seqs.push_back(transform::CanonicalizeOps());
306316

307-
// Legalize pass is restricted to homogeneous execution for now.
308-
if (targets.size() == 1) {
309-
pass_seqs.push_back(transform::Legalize());
310-
}
311-
312317
// Alter layout transformation is only applied to homogeneous execution yet.
313318
if (targets.size() == 1) {
314319
pass_seqs.push_back(transform::AlterOpLayout());

src/relay/pass/legalize.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
9595
[=](Function f, Module m, PassContext pc) {
9696
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
9797
};
98-
return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
98+
return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
9999
}
100100

101101
TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);

src/relay/qnn/pass/legalize.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file relay/qnn/pass/legalize.cc
22+
* \brief The Legalize wrapper for QNN.
23+
*/
24+
25+
#include <tvm/relay/qnn/transform.h>
26+
27+
namespace tvm {
28+
namespace relay {
29+
namespace qnn {
30+
31+
namespace transform {
32+
33+
Pass Legalize() {
34+
Array<Pass> pass_seqs;
35+
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
36+
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
37+
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
38+
return seq;
39+
}
40+
41+
TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);
42+
43+
} // namespace transform
44+
45+
} // namespace qnn
46+
} // namespace relay
47+
} // namespace tvm

tests/python/relay/test_op_qnn_conv2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def get_qnn_func(data,
7878

7979
mod = relay.Function(relay.analysis.free_vars(func), func)
8080
mod = relay.Module.from_expr(mod)
81-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
8281
return mod
8382

8483
def get_funcs(data_shape,

tests/python/relay/test_op_qnn_dequantize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
3131
input_zero_point=input_zero_point)
3232
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
3333
mod = relay.Module.from_expr(mod)
34-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
3534
with relay.build_config(opt_level=3):
3635
graph, lib, params = relay.build(mod, "llvm", params=None)
3736
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))

tests/python/relay/test_op_qnn_quantize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output
3131
output_zero_point=output_zero_point,out_dtype=out_dtype)
3232
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
3333
mod = relay.Module.from_expr(mod)
34-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
3534
with relay.build_config(opt_level=3):
3635
graph, lib, params = relay.build(mod, "llvm", params=None)
3736
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))

tests/python/relay/test_op_qnn_requantize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
4949

5050
mod = relay.Function(relay.analysis.free_vars(mod), mod)
5151
mod = relay.Module.from_expr(mod)
52-
mod = relay.qnn.transform.CanonicalizeOps()(mod)
5352
return mod
5453

5554
def same_scale_test():

0 commit comments

Comments
 (0)