-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[RELAY][PASS] add a relay pass to count #macs of a model #2609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
453fc03
164046b
ce46195
2f18c5f
c515ecd
a91ea27
885ba93
c47dca8
b84122a
38beaa3
5de327a
4ee456f
5bf2f4e
5422405
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| /*! | ||
| * Copyright (c) 2019 by Contributors | ||
| * | ||
| * \file mac_count.cc | ||
| * \brief Pass to roughly count the number of MACs (Multiply-Accumulate) | ||
| * operations of a model. Only MACs in CONV and Dense ops are counted. | ||
| * This pass is valid after the type infer pass is called, | ||
| * otherwise the count is 0. | ||
| */ | ||
|
|
||
| #include <tvm/relay/op.h> | ||
| #include <tvm/relay/attrs/nn.h> | ||
| #include <tvm/relay/expr_functor.h> | ||
| #include "../op/layout.h" | ||
|
|
||
| namespace tvm { | ||
| namespace relay { | ||
|
|
||
| namespace { | ||
|
|
||
| bool IsConv2DNode(const ExprNode* node) { | ||
| const auto* call_node = dynamic_cast<const CallNode*>(node); | ||
| return call_node != nullptr && call_node->attrs.as<Conv2DAttrs>(); | ||
| } | ||
|
|
||
| bool IsDenseNode(const ExprNode* node) { | ||
| const auto* call_node = dynamic_cast<const CallNode*>(node); | ||
| return call_node != nullptr && call_node->attrs.as<DenseAttrs>(); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| class MacCounter : private ExprVisitor { | ||
| public: | ||
| MacCounter() { | ||
| count_ = 0; | ||
| } | ||
| static int64_t GetTotalMacNumber(const Expr& expr) { | ||
| LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops"; | ||
| MacCounter counter; | ||
| counter(expr); | ||
| return counter.count_; | ||
| } | ||
|
|
||
| private: | ||
| void VisitExpr_(const CallNode* call_node) final { | ||
| if (IsConv2DNode(call_node)) { | ||
| count_ += ComputeConv2DMacs(call_node); | ||
| } else if (IsDenseNode(call_node)) { | ||
| count_ += ComputeDenseMacs(call_node); | ||
| } | ||
| ExprVisitor::VisitExpr_(call_node); | ||
| } | ||
|
|
||
| /* | ||
| * \brief Get the number of MACs of a CONV 2D node. | ||
| * \param call_node The CONV 2D call node. | ||
| * \return The number of MACs. | ||
| */ | ||
| int64_t ComputeConv2DMacs(const CallNode* call_node) { | ||
| CHECK(IsConv2DNode(call_node)) | ||
| << "The input call node must be a CONV 2D node."; | ||
| if (!call_node->checked_type_.defined()) { | ||
| LOG(WARNING) << "The infer type pass should be called before the mac count pass"; | ||
| return 0; | ||
| } | ||
| Array<Expr> args = call_node->args; | ||
| CHECK(args.size() == 2) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if we still need these checks. Are they supposed to be valid already since we have already inferred types.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it wouldn't hurt. This code can still run with warning if the infer type pass is not called. |
||
| << "The number of input arguments of a CONV 2D node should be 2."; | ||
| const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>(); | ||
| const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); | ||
| Array<IndexExpr> data_shape = data_type->shape; | ||
| std::string data_layout = conv_2d_attr->data_layout; | ||
| int32_t C_ind = Layout(data_layout).Indexof('C'); | ||
| int32_t c_ind = Layout(data_layout).Indexof('c'); | ||
| CHECK(C_ind != -1 || c_ind != -1) | ||
| << "There is no input channel dimension."; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| int64_t input_channel = 1; | ||
| if (C_ind != -1) | ||
| input_channel *= static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value); | ||
| if (c_ind != -1) | ||
| input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value); | ||
| Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size; | ||
| CHECK(kernel_size.size() == 2) | ||
| << "The dimension of the kernel size in Conv 2D should be 2."; | ||
| const auto* expr = call_node->checked_type().as<TensorTypeNode>(); | ||
| Array<IndexExpr> output_tensor = expr->shape; | ||
| CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) | ||
| << "The dimension of the output tensor in Conv 2D should be 4 or 5."; | ||
| int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); | ||
| return count; | ||
| } | ||
|
|
||
| /* | ||
| * \brief Get the number of MACs of a Dense node. | ||
| * \param call_node The Dense call node. | ||
| * \return The number of MACs. | ||
| */ | ||
| int64_t ComputeDenseMacs(const CallNode* call_node) { | ||
| CHECK(IsDenseNode(call_node)) | ||
| << "The input call node must be a Dense node."; | ||
| if (!call_node->checked_type_.defined()) { | ||
| LOG(WARNING) << "The infer type pass should be called before the mac count pass"; | ||
| return 0; | ||
| } | ||
| Array<Expr> args = call_node->args; | ||
| CHECK(args.size() == 2) | ||
| << "The number of input arguments of a Dense node should be 2."; | ||
| const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); | ||
| const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>(); | ||
| Array<IndexExpr> data_shape = data_type->shape; | ||
| Array<IndexExpr> weight_shape = weight_type->shape; | ||
| CHECK(data_shape.size() == 2 && weight_shape.size() == 2) | ||
| << "The dimension of an input tensor to Dense node should be 2."; | ||
| int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImm>()->value); | ||
| int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value); | ||
| int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value); | ||
| int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value); | ||
| CHECK(d2 == d4) | ||
| << "The dimensions of input arguments do not match."; | ||
| int64_t count = d1 * d2 * d3; | ||
| return count; | ||
| } | ||
|
|
||
| int64_t GetCartesianProd(Array<IndexExpr> arr) { | ||
| int64_t ret = 1; | ||
| for (size_t i = 0; i < arr.size(); i++) { | ||
| const auto* intImm = arr[i].as<IntImm>(); | ||
| ret *= static_cast<int64_t>(intImm->value); | ||
| } | ||
| return ret; | ||
| } | ||
|
|
||
| int64_t count_; | ||
| }; | ||
|
|
||
| int64_t GetTotalMacNumber(const Expr& expr) { | ||
| return MacCounter::GetTotalMacNumber(expr); | ||
| } | ||
|
|
||
| TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") | ||
| .set_body([](TVMArgs args, TVMRetValue *ret) { | ||
| *ret = GetTotalMacNumber(args[0]); | ||
| }); | ||
|
|
||
| } // namespace relay | ||
| } // namespace tvm | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| """Unit tests for MAC counter.""" | ||
| import tvm | ||
| from tvm import relay | ||
| import sys | ||
|
|
||
| def test_gemm(): | ||
| n = 512 | ||
| k = 1024 | ||
| m = 256 | ||
| dshape1 = (n, k) | ||
| dshape2 = (m, k) | ||
| data1 = relay.var("data1", shape=dshape1) | ||
| data2 = relay.var("data2", shape=dshape2) | ||
| gemm = relay.nn.dense(data1, data2) | ||
| func = relay.Function([data1, data2], | ||
| relay.Tuple(tvm.convert([gemm]))) | ||
| func = relay.ir_pass.infer_type(func) | ||
| compute_count = relay.ir_pass.get_total_mac_number(func) | ||
| expect_count = n * m * k | ||
| assert compute_count == expect_count | ||
|
|
||
| def test_conv(): | ||
| batch_size = 1 | ||
| input_channel = 3 | ||
| h = 224 | ||
| w = 224 | ||
| output_channel = 64 | ||
| kh = 7 | ||
| kw = 7 | ||
| h_padding = 1 | ||
| w_padding = 1 | ||
| oh = h + h_padding * 2 - kh + 1 | ||
| ow = w + w_padding * 2 - kw + 1 | ||
| dshape = (batch_size, input_channel, h, w) | ||
| weight = relay.var("weight", shape=(output_channel, input_channel, kh, kw)) | ||
| data = relay.var("data", shape=dshape) | ||
| conv2d = relay.nn.conv2d( | ||
| data, | ||
| weight, | ||
| channels=output_channel, | ||
| kernel_size=(kh, kw), | ||
| padding=(1, 1)) | ||
| func = relay.Function([data, weight], | ||
| relay.Tuple(tvm.convert([conv2d]))) | ||
| func = relay.ir_pass.infer_type(func) | ||
| compute_count = relay.ir_pass.get_total_mac_number(func) | ||
| expect_count = batch_size * input_channel * oh * ow * output_channel * kh * kw | ||
| assert compute_count == expect_count | ||
|
|
||
| def test_simple_network(): | ||
| batch_size = 1 | ||
| dshape = (batch_size, 64, 56, 56) | ||
| weight_conv = relay.var("weight_conv", shape=(64, 64, 3, 3)) | ||
| data1 = relay.var("data1", shape=dshape) | ||
| data2 = relay.var("data2", shape=dshape) | ||
| weight_dense = relay.var("weight_dense", shape=(1, 56*56*64)) | ||
|
|
||
| conv2d_1 = relay.nn.conv2d( | ||
| data1, | ||
| weight_conv, | ||
| channels=64, | ||
| kernel_size=(3, 3), | ||
| padding=(1, 1)) | ||
| conv2d_2 = relay.nn.conv2d( | ||
| data2, | ||
| weight_conv, | ||
| channels=64, | ||
| kernel_size=(3, 3), | ||
| padding=(1, 1)) | ||
| add = relay.add(conv2d_1, conv2d_2) | ||
| flattened = relay.nn.batch_flatten(add) | ||
| dense_1 = relay.nn.dense( | ||
| flattened, | ||
| weight_dense) | ||
|
|
||
| func = relay.Function([data1, data2, weight_conv, weight_dense], | ||
| relay.Tuple(tvm.convert([conv2d_1, conv2d_2, | ||
| dense_1, add, flattened]))) | ||
| func = relay.ir_pass.infer_type(func) | ||
| # alter the CONV 2D data layout to test | ||
| func = relay.ir_pass.alter_op_layout(func) | ||
| func = relay.ir_pass.infer_type(func) | ||
| compute_count = relay.ir_pass.get_total_mac_number(func) | ||
| expect_count = 231411712 | ||
| assert compute_count == expect_count | ||
|
|
||
| if __name__ == "__main__": | ||
| test_conv() | ||
| test_gemm() | ||
| test_simple_network() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we log somewhere to remind users the pass only calculates MAC for Conv2D and Dense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a log info in the beginning of the pass.