Skip to content

Commit cc09497

Browse files
vinx13ZihengJiang
authored andcommitted
[Relay, Quantization, TOPI] int8 dense on CUDA & Dense op quantization (#2877)
* Quantize dense layers * Add out_dtype arggument to dense; Add dense_int8 on CUDA * Add topi unittest of dense int8 * Fix relay * Fix topi integration * Fix quantization * Update dense_rewrite * Triger CI * Change qconfig quantize_dense to quantize_op * Fix * Remove quantize_op from qconfig
1 parent 879f91d commit cc09497

20 files changed

Lines changed: 326 additions & 49 deletions

File tree

include/tvm/relay/attrs/nn.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,16 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
336336
/*! \brief Attributes for dense operator */
337337
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
338338
IndexExpr units;
339+
DataType out_dtype;
339340

340341
TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
341342
TVM_ATTR_FIELD(units)
342343
.describe("Number of hidden units of the dense transformation.");
344+
345+
// use 0 bits to indicate none.
346+
TVM_ATTR_FIELD(out_dtype)
347+
.set_default(NullValue<DataType>())
348+
.describe("Output data type, set to explicit type under mixed precision setting");
343349
}
344350
};
345351

python/tvm/autotvm/task/topi_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
188188
def _topi_nn_dense(*args, **kwargs):
189189
assert not kwargs, "Do not support kwargs in template function call"
190190
args = deserialize_args(args)
191-
data, weight, bias = args
191+
data, weight, bias, _ = args
192192
C = topi.nn.dense(*args, **kwargs)
193193
s = topi.generic.schedule_dense([C])
194194
if bias is not None:

python/tvm/relay/op/nn/_nn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def schedule_log_softmax(_, outputs, target):
5151
@reg.register_compute("nn.dense")
5252
def compute_dense(attrs, inputs, out_type, target):
5353
"""Compute definition of dense"""
54-
return [topi.nn.dense(inputs[0], inputs[1])]
54+
out_dtype = attrs.out_dtype
55+
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
56+
return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]
5557

5658
@reg.register_schedule("nn.dense")
5759
def schedule_dense(attrs, outputs, target):

python/tvm/relay/op/nn/nn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def bias_add(data, bias, axis=1):
475475
return _make.bias_add(data, bias, axis)
476476

477477

478-
def dense(data, weight, units=None):
478+
def dense(data, weight, units=None, out_dtype=""):
479479
"""Dense operator.
480480
Applies a linear transformation
481481
@@ -494,12 +494,15 @@ def dense(data, weight, units=None):
494494
units : int, optional
495495
Number of hidden units of the dense transformation.
496496
497+
out_dtype : str, optional
498+
Specifies the output data type for mixed precision dense.
499+
497500
Returns
498501
-------
499502
result : tvm.relay.Expr
500503
The computed result.
501504
"""
502-
return _make.dense(data, weight, units)
505+
return _make.dense(data, weight, units, out_dtype)
503506

504507

505508
def relu(data):

python/tvm/relay/quantize/_annotate.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,26 @@ def conv2d_rewrite(ref_call, new_args, ctx):
171171
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
172172

173173

174+
@register_annotate_function("nn.dense")
175+
def dense_rewrite(ref_call, new_args, ctx):
176+
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
177+
dense will be quantized to weight field. Output would be in activation field."""
178+
cnt = _conv_counter()
179+
if cnt < current_qconfig().skip_k_conv:
180+
return None
181+
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
182+
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
183+
184+
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
185+
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
186+
187+
assert rhs_kind is None
188+
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
189+
190+
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
191+
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
192+
193+
174194
@register_annotate_function("multiply")
175195
def multiply_rewrite(ref_call, new_args, ctx):
176196
"""Rewrite function for multiply."""

src/relay/op/nn/nn.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,24 @@ bool DenseRel(const Array<Type>& types,
131131
oshape.Set((oshape.size() - 1), wshape[0]);
132132
}
133133

134+
DataType out_dtype = param->out_dtype;
135+
if (out_dtype.bits() == 0) {
136+
out_dtype = data->dtype;
137+
}
134138
// assign output type
135-
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
139+
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
136140
return true;
137141
}
138142

139143

140144
// Positional relay function to create dense operator used by frontend FFI.
141145
Expr MakeDense(Expr data,
142146
Expr weight,
143-
IndexExpr units) {
147+
IndexExpr units,
148+
DataType out_dtype) {
144149
auto attrs = make_node<DenseAttrs>();
145150
attrs->units = units;
151+
attrs->out_dtype = out_dtype;
146152
static const Op& op = Op::Get("nn.dense");
147153
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
148154
}

src/relay/pass/quantize.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,39 @@ RELAY_REGISTER_OP("nn.conv2d")
296296
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
297297

298298

299+
Expr DenseRealize(const Call& ref_call,
300+
const Array<Expr>& new_args,
301+
const NodeRef& ctx) {
302+
const QConfig& cfg = QConfig::Current();
303+
CHECK_EQ(new_args.size(), 2);
304+
if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
305+
return Expr(nullptr);
306+
}
307+
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
308+
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
309+
310+
Expr ldata = lhs->data;
311+
if (lhs->dtype != cfg->dtype_input) {
312+
ldata = Cast(ldata, cfg->dtype_input);
313+
}
314+
Expr rdata = Cast(rhs->data, cfg->dtype_weight);
315+
316+
const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
317+
auto attrs = make_node<DenseAttrs>();
318+
*attrs = *ref_attrs;
319+
DataType out_dtype = cfg->dtype_activation;
320+
attrs->out_dtype = out_dtype;
321+
322+
Expr ret = CallNode::make(ref_call->op,
323+
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
324+
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
325+
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
326+
}
327+
328+
RELAY_REGISTER_OP("nn.dense")
329+
.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
330+
331+
299332
Expr MulRealize(const Call& ref_call,
300333
const Array<Expr>& new_args,
301334
const NodeRef& ctx) {

topi/include/topi/cuda/dense.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,15 @@ namespace cuda {
4444
* \param data Tensor with shape [batch, in_dim]
4545
* \param weight Tensor with shape [out_dim, in_dim]
4646
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
47+
* \param out_dtype Output data type. Used for mixed precision.
4748
*
4849
* \return Tensor with shape [batch, out_dim]
4950
*/
5051
inline tvm::Tensor dense_cuda(const Target& target,
5152
const tvm::Tensor& data,
5253
const tvm::Tensor& weight,
53-
const tvm::Tensor& bias) {
54+
const tvm::Tensor& bias,
55+
const Type& out_dtype) {
5456
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
5557
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
5658
if (bias.defined()) {
@@ -62,6 +64,7 @@ inline tvm::Tensor dense_cuda(const Target& target,
6264
auto out_dim = weight->shape[0];
6365

6466
if (target->libs().count("cublas")) {
67+
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
6568
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
6669
if (bias.defined()) {
6770
mm = tvm::compute({ batch, out_dim },
@@ -72,7 +75,7 @@ inline tvm::Tensor dense_cuda(const Target& target,
7275

7376
return mm;
7477
} else {
75-
return topi::nn::dense(data, weight, bias);
78+
return topi::nn::dense(data, weight, bias, out_dtype);
7679
}
7780
}
7881

topi/include/topi/nn/dense.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ using namespace tvm;
4040
* \param data Tensor with shape [batch, in_dim]
4141
* \param weight Tensor with shape [out_dim, in_dim]
4242
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
43+
* \param out_dtype Output data type. Used for mixed precision.
4344
*
4445
* \return Tensor with shape [batch, out_dim]
4546
*/
4647
inline tvm::Tensor dense(const tvm::Tensor& data,
4748
const tvm::Tensor& weight,
48-
const tvm::Tensor& bias) {
49+
const tvm::Tensor& bias,
50+
const Type& out_dtype) {
4951
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
5052
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
5153
if (bias.defined()) {
@@ -60,14 +62,15 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
6062
auto matmul = tvm::compute(
6163
{ batch, out_dim },
6264
[&](Var i, Var j) {
63-
return tvm::sum(data(i, k) * weight(j, k), { k });
65+
return tvm::sum(tvm::cast(out_dtype, data(i, k)) *
66+
tvm::cast(out_dtype, weight(j, k)), { k });
6467
}, "tensor", "dense");
6568

6669
if (bias.defined()) {
6770
matmul = tvm::compute(
6871
{ batch, out_dim },
6972
[&](Var i, Var j) {
70-
return matmul(i, j) + bias(j);
73+
return matmul(i, j) + tvm::cast(out_dtype, bias(j));
7174
}, "tensor", kBroadcast);
7275
}
7376

topi/include/topi/rocm/dense.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@ namespace rocm {
4545
* \param data Tensor with shape [batch, in_dim]
4646
* \param weight Tensor with shape [out_dim, in_dim]
4747
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
48+
* \param out_dtype Output data type. Used for mixed precision.
4849
*
4950
* \return Tensor with shape [batch, out_dim]
5051
*/
5152
inline tvm::Tensor dense_rocm(const Target& target,
5253
const tvm::Tensor& data,
5354
const tvm::Tensor& weight,
54-
const tvm::Tensor& bias) {
55+
const tvm::Tensor& bias,
56+
const Type& out_dtype) {
5557
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
5658
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
5759
if (bias.defined()) {
@@ -63,6 +65,7 @@ inline tvm::Tensor dense_rocm(const Target& target,
6365
auto out_dim = weight->shape[0];
6466

6567
if (target->libs().count("rocblas")) {
68+
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
6669
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
6770
if (bias.defined()) {
6871
mm = tvm::compute({ batch, out_dim },
@@ -73,7 +76,7 @@ inline tvm::Tensor dense_rocm(const Target& target,
7376

7477
return mm;
7578
} else {
76-
return topi::nn::dense(data, weight, bias);
79+
return topi::nn::dense(data, weight, bias, out_dtype);
7780
}
7881
}
7982

0 commit comments

Comments
 (0)