Skip to content

Commit a473635

Browse files
committed
fix lowering process for using dequantize and quantize ops
1 parent d962b0f commit a473635

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

src/relay/qnn/op/op_common.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/relay/op.h>
2929
#include <tvm/relay/op_attr_types.h>
3030
#include <tvm/relay/qnn/attrs.h>
31+
#include <tvm/relay/qnn/transform.h>
3132

3233
#include <vector>
3334

@@ -321,6 +322,19 @@ static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_
321322
return IdentityRel(tensor_types, 2, attrs, reporter);
322323
}
323324

325+
/*!
 * \brief Lower an expression by running transform::Legalize() over it.
 *
 * Canonicalizations should not contain qnn ops. When a lowering is built out
 * of ops such as qnn.dequantize, pass its result through this helper so the
 * expression is lowered automatically.
 *
 * \param expr The expression to lower.
 * \return The "main" entry of the legalized module when expr is itself a
 *         function; otherwise the body of that "main" function.
 */
static inline Expr LegalizeExpr(const Expr& expr) {
  // Wrap the expression in a module so that a module-level pass can be applied.
  auto wrapped = IRModule::FromExpr(expr);
  auto lowered = transform::Legalize()(wrapped);
  auto main_func = lowered->Lookup("main");

  // A non-function input only wants the lowered expression back, i.e. the
  // body of the "main" function.
  const bool input_is_function = expr.as<FunctionNode>() != nullptr;
  if (!input_is_function) {
    return main_func.as<FunctionNode>()->body;
  }
  return main_func;
}
337+
324338
/*! Quick helper macro
325339
* - Expose a positional make function to construct the node.
326340
* - Register op to the registry.
@@ -362,11 +376,12 @@ static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_
362376
[](const Attrs& attrs, const Array<Expr>& new_args, const Array<tvm::relay::Type>& arg_types) { \
363377
QnnUnaryOpArguments args(new_args); \
364378
QnnUnaryOpTensorType input_type(arg_types, 0); \
365-
auto dequantized_arg = MakeDequantize(args.x, args.scale, args.zero_point, -1); \
366-
auto output = FloatingPointFunc(dequantized_arg); \
367-
return MakeQuantize(output, args.output_scale, args.output_zero_point, -1, input_type.dtype); \
379+
Expr dequantized_arg = MakeDequantize(args.x, args.scale, args.zero_point, -1); \
380+
Expr output = FloatingPointFunc(dequantized_arg); \
381+
Expr result = \
382+
MakeQuantize(output, args.output_scale, args.output_zero_point, -1, input_type.dtype); \
383+
return LegalizeExpr(result); \
368384
}
369-
370385
} // namespace qnn
371386
} // namespace relay
372387
} // namespace tvm

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
def compare_fq_to_int(expr, args, allow_rounding_error=False):
2525
mod = tvm.IRModule.from_expr(expr)
2626
mod = tvm.relay.transform.InferType()(mod)
27-
2827
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
2928
assert not tvm.ir.structural_equal(mod, mod_int)
3029

0 commit comments

Comments
 (0)