|
28 | 28 | #include <tvm/relay/op.h> |
29 | 29 | #include <tvm/relay/op_attr_types.h> |
30 | 30 | #include <tvm/relay/qnn/attrs.h> |
| 31 | +#include <tvm/relay/qnn/transform.h> |
31 | 32 |
|
32 | 33 | #include <vector> |
33 | 34 |
|
@@ -321,6 +322,19 @@ static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_ |
321 | 322 | return IdentityRel(tensor_types, 2, attrs, reporter); |
322 | 323 | } |
323 | 324 |
|
| 325 | +static inline Expr LegalizeExpr(const Expr& expr) { |
| 326 | + // Canonicalizations should not contain qnn ops, so use this |
| 327 | + // to lower expressions automatically after using things like qnn.dequantize |
| 328 | + // in the lowering process. |
| 329 | + auto mod = IRModule::FromExpr(expr); |
| 330 | + mod = transform::Legalize()(mod); |
| 331 | + if (expr.as<FunctionNode>()) { |
| 332 | + return mod->Lookup("main"); |
| 333 | + } else { |
| 334 | + return mod->Lookup("main").as<FunctionNode>()->body; |
| 335 | + } |
| 336 | +} |
| 337 | + |
324 | 338 | /*! Quick helper macro |
325 | 339 | * - Expose a positional make function to construct the node. |
326 | 340 | * - Register op to the registry. |
@@ -362,11 +376,12 @@ static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_ |
362 | 376 | [](const Attrs& attrs, const Array<Expr>& new_args, const Array<tvm::relay::Type>& arg_types) { \ |
363 | 377 | QnnUnaryOpArguments args(new_args); \ |
364 | 378 | QnnUnaryOpTensorType input_type(arg_types, 0); \ |
365 | | - auto dequantized_arg = MakeDequantize(args.x, args.scale, args.zero_point, -1); \ |
366 | | - auto output = FloatingPointFunc(dequantized_arg); \ |
367 | | - return MakeQuantize(output, args.output_scale, args.output_zero_point, -1, input_type.dtype); \ |
| 379 | + Expr dequantized_arg = MakeDequantize(args.x, args.scale, args.zero_point, -1); \ |
| 380 | + Expr output = FloatingPointFunc(dequantized_arg); \ |
| 381 | + Expr result = \ |
| 382 | + MakeQuantize(output, args.output_scale, args.output_zero_point, -1, input_type.dtype); \ |
| 383 | + return LegalizeExpr(result); \ |
368 | 384 | } |
369 | | - |
370 | 385 | } // namespace qnn |
371 | 386 | } // namespace relay |
372 | 387 | } // namespace tvm |
|
0 commit comments