183 changes: 179 additions & 4 deletions python/tvm/relay/qnn/op/qnn.py
@@ -22,16 +22,16 @@
 import tvm
 import tvm.ir
 from tvm import relay
-from tvm.runtime import Object
 from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.utils import get_pad_tuple2d
-from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
+from tvm.runtime import Object
 from tvm.target import Target
+from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
 from tvm.topi.x86.utils import target_has_sse41

 from ... import op as reg
 from ...op import OpPattern
-from . import _make
-from . import _requantize
+from . import _make, _requantize


@tvm._ffi.register_object("relay.qnn.op.RequantizeConfig")
@@ -750,6 +750,111 @@ def mul(
)


def tanh(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized tanh.

Parameters
----------
x : relay.Expr
The quantized input tensor.

scale: relay.Expr
The scale of the quantized expr.

zero_point: relay.Expr
The zero point of the quantized expr.

output_scale: relay.Expr
The scale of the output quantized expr.

output_zero_point: relay.Expr
The zero point of the output quantized expr.

Returns
-------
result : relay.Expr
The computed result.

"""
return _make.tanh(
x,
scale,
zero_point,
output_scale,
output_zero_point,
)
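
(Illustrative sketch, not part of this diff: a minimal example of calling the new op from Python. The shape and the quantization parameters are arbitrary example values; the same calling convention applies to exp, sqrt, erf, and sigmoid below.)

from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="int8")
y = relay.qnn.op.tanh(
    x,
    scale=relay.const(2.0 / 128, "float32"),         # input scale
    zero_point=relay.const(0, "int32"),              # input zero point
    output_scale=relay.const(1.0 / 128, "float32"),  # tanh output lies in [-1, 1]
    output_zero_point=relay.const(0, "int32"),
)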


def exp(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized exponential function.

Parameters
----------
x : relay.Expr
The quantized input tensor.

scale: relay.Expr
The scale of the quantized expr.

zero_point: relay.Expr
The zero point of the quantized expr.

output_scale: relay.Expr
The scale of the output quantized expr.

output_zero_point: relay.Expr
The zero point of the output quantized expr.

Returns
-------
result : relay.Expr
The computed result.

"""
return _make.exp(
x,
scale,
zero_point,
output_scale,
output_zero_point,
)


def sqrt(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized square root.

Parameters
----------
x : relay.Expr
The quantized input tensor.

scale: relay.Expr
The scale of the quantized expr.

zero_point: relay.Expr
The zero point of the quantized expr.

output_scale: relay.Expr
The scale of the output quantized expr.

output_zero_point: relay.Expr
The zero point of the output quantized expr.

Returns
-------
result : relay.Expr
The computed result.

"""
return _make.sqrt(
x,
scale,
zero_point,
output_scale,
output_zero_point,
)


def rsqrt(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized reciprocal square root.

@@ -785,6 +890,76 @@ def rsqrt(x, scale, zero_point, output_scale, output_zero_point):
)


def erf(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized error function.

Parameters
----------
x : relay.Expr
The quantized input tensor.

scale: relay.Expr
The scale of the quantized expr.

zero_point: relay.Expr
The zero point of the quantized expr.

output_scale: relay.Expr
The scale of the output quantized expr.

output_zero_point: relay.Expr
The zero point of the output quantized expr.

Returns
-------
result : relay.Expr
The computed result.

"""
return _make.erf(
x,
scale,
zero_point,
output_scale,
output_zero_point,
)


def sigmoid(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized sigmoid.

Parameters
----------
x : relay.Expr
The quantized input tensor.

scale: relay.Expr
The scale of the quantized expr.

zero_point: relay.Expr
The zero point of the quantized expr.

output_scale: relay.Expr
The scale of the output quantized expr.

output_zero_point: relay.Expr
The zero point of the output quantized expr.

Returns
-------
result : relay.Expr
The computed result.

"""
return _make.sigmoid(
x,
scale,
zero_point,
output_scale,
output_zero_point,
)
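
(Illustrative sketch, not part of this diff: running one of the new ops end to end, assuming the standard Relay build pipeline legalizes qnn ops via the canonicalizations registered in op_common.h below.)

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="int8")
y = relay.qnn.op.sigmoid(
    x,
    relay.const(0.05, "float32"),       # input scale
    relay.const(0, "int32"),            # input zero point
    relay.const(1.0 / 256, "float32"),  # output scale: sigmoid output lies in (0, 1)
    relay.const(-128, "int32"),         # output zero point
)
mod = tvm.IRModule.from_expr(relay.Function([x], y))
run = relay.create_executor("graph", mod=mod, target="llvm").evaluate()
print(run(np.array([-100, -1, 1, 100], dtype="int8")))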


def subtract(
lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
93 changes: 93 additions & 0 deletions src/relay/qnn/op/op_common.h
@@ -28,6 +28,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/qnn/transform.h>

#include <vector>

@@ -289,6 +290,98 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
.set_attr<TNonComputational>("TNonComputational", true) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)

static inline bool QnnElementwiseUnaryFuncRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
// Expected Types: data, scale, zero_point, output_scale, output_zero_point, output
ICHECK_EQ(types.size(), 6);
const auto* x = types[0].as<TensorTypeNode>();
if (x == nullptr) return false;
ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
<< "Expected quantized type(int8, uint8) for input but was " << x->dtype;

// Check the types of scale and zero points.
for (size_t i = 1; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[1], DataType::Float(32))); // scale
ICHECK(IsScalarType(types[2], DataType::Int(32))); // zero_point
ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point

// Assign types for scale and zero points.
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // scale
reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // zero_point
reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // output_scale
reporter->Assign(types[4], TensorType({}, DataType::Int(32))); // output_zero_point

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// IdentityRel infer type function.
Array<Type> tensor_types = {types[0], types[5]};
return IdentityRel(tensor_types, 2, attrs, reporter);
}

static inline Expr LegalizeExpr(const Expr& expr) {
// Canonicalized expressions should not contain qnn ops, so use this helper to
// automatically lower any qnn ops (such as qnn.dequantize) that were introduced
// during the lowering process.
auto mod = IRModule::FromExpr(expr);
mod = transform::Legalize()(mod);
if (expr.as<FunctionNode>()) {
return mod->Lookup("main");
} else {
return mod->Lookup("main").as<FunctionNode>()->body;
}
}

/*! Quick helper macro
* - Expose a positional make function to construct the node.
* - Register op to the registry.
*
 * For unary operators which also take in QParams (quantization parameters).
*
 * \param OpName the name of the operator in the registry.
*/
#define QNN_CREATE_UNARY_ELEMENTWISE_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \
.set_body_typed( \
[](Expr x, Expr scale, Expr zero_point, Expr output_scale, Expr output_zero_point) { \
return Call(Op::Get("qnn." OpName), \
{x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {}); \
}); \
\
RELAY_REGISTER_OP("qnn." OpName) \
.describe("Elementwise " OpName " for quantized tensors.") \
.set_num_inputs(5) \
.add_argument("data", "Quantized Tensor", "The input data.") \
.add_argument("scale", "Tensor", "The quantization scale of the input tensor.") \
.add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") \
.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") \
.add_argument("output_zero_point", "Tensor", \
"The quantization zero_point of the output tensor.") \
.set_support_level(11) \
.add_type_rel("qnn." OpName, QnnElementwiseUnaryFuncRel) \
.set_attr<TNonComputational>("TNonComputational", true)

/*! Quick helper macro
 * Create a default canonicalization for a QNN operator: it dequantizes the input,
 * runs the calculation using the provided floating-point function, and then
 * requantizes the result.
*
* FloatingPointFunc is usually a handle from "src/relay/transforms/pattern_utils.h"
*
 * \param FloatingPointFunc the floating-point function, with a signature like `Expr Erf(Expr e)`
*/
#define QNN_UNARY_OP_DEFAULT_CANONICALIZATION(FloatingPointFunc) \
[](const Attrs& attrs, const Array<Expr>& new_args, const Array<tvm::relay::Type>& arg_types) { \
QnnUnaryOpArguments args(new_args); \
QnnUnaryOpTensorType input_type(arg_types, 0); \
Expr dequantized_arg = MakeDequantize(args.x, args.scale, args.zero_point, -1); \
Expr output = FloatingPointFunc(dequantized_arg); \
Expr result = \
MakeQuantize(output, args.output_scale, args.output_zero_point, -1, input_type.dtype); \
return LegalizeExpr(result); \
}
} // namespace qnn
} // namespace relay
} // namespace tvm
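
(Illustrative sketch, not part of this diff: the semantics that QNN_UNARY_OP_DEFAULT_CANONICALIZATION gives each op, written out in NumPy — dequantize with the input QParams, apply the floating-point function, then requantize with the output QParams. The helper name and the int8 saturation bounds are assumptions of the sketch; QnnElementwiseUnaryFuncRel also admits uint8 inputs, whose bounds would be [0, 255].)

import numpy as np

def qnn_unary_reference(x_q, scale, zero_point, output_scale, output_zero_point, func):
    # Dequantize: map the int8 values back into float32.
    x_f = scale * (x_q.astype("float32") - zero_point)
    # Apply the floating-point function (e.g. np.tanh, np.exp, np.sqrt).
    y_f = func(x_f)
    # Requantize: scale onto the output grid, round, shift by the output
    # zero point, and saturate to the int8 range.
    y_q = np.round(y_f / output_scale) + output_zero_point
    return np.clip(y_q, -128, 127).astype("int8")

# Example: quantized tanh with symmetric input/output quantization.
x_q = np.array([-128, -64, 0, 64, 127], dtype="int8")
print(qnn_unary_reference(x_q, 2.0 / 128, 0, 1.0 / 128, 0, np.tanh))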