Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,19 @@ def gradient(expr, mod=None):
The output expression.
"""
return _ir_pass.first_order_gradient(expr, mod)

def get_total_mac_number(expr):
    """Roughly count the multiply-accumulate (MAC) operations in a model.

    Only the MACs performed by conv2d and dense calls are counted, and the
    expression must already have been through the type-inference pass;
    otherwise the reported count is 0.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression whose MACs are counted.

    Returns
    -------
    ret : int64
        The total number of MACs in the model.
    """
    return _ir_pass.GetTotalMacNumber(expr)
147 changes: 147 additions & 0 deletions src/relay/pass/mac_count.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*!
* Copyright (c) 2019 by Contributors
*
* \file mac_count.cc
* \brief Pass to roughly count the number of MACs (Multiply-Accumulate)
* operations of a model. Only MACs in CONV and Dense ops are counted.
* This pass is valid after the type infer pass is called,
* otherwise the count is 0.
*/

#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "../op/layout.h"

namespace tvm {
namespace relay {

namespace {

bool IsConv2DNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<Conv2DAttrs>();
}

bool IsDenseNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<DenseAttrs>();
}

} // namespace

class MacCounter : private ExprVisitor {
public:
MacCounter() {
count_ = 0;
}
static int64_t GetTotalMacNumber(const Expr& expr) {
LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops";
MacCounter counter;
counter(expr);
return counter.count_;
}

private:
void VisitExpr_(const CallNode* call_node) final {
if (IsConv2DNode(call_node)) {
count_ += ComputeConv2DMacs(call_node);
} else if (IsDenseNode(call_node)) {
count_ += ComputeDenseMacs(call_node);
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we log somewhere to remind users the pass only calculates MAC for Conv2D and Dense?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a log info in the beginning of the pass.

ExprVisitor::VisitExpr_(call_node);
}

/*
* \brief Get the number of MACs of a CONV 2D node.
* \param call_node The CONV 2D call node.
* \return The number of MACs.
*/
int64_t ComputeConv2DMacs(const CallNode* call_node) {
CHECK(IsConv2DNode(call_node))
<< "The input call node must be a CONV 2D node.";
if (!call_node->checked_type_.defined()) {
LOG(WARNING) << "The infer type pass should be called before the mac count pass";
return 0;
}
Array<Expr> args = call_node->args;
CHECK(args.size() == 2)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we still need these checks. Are they supposed to be valid already since we have already inferred types.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it wouldn't hurt. This code can still run with warning if the infer type pass is not called.

<< "The number of input arguments of a CONV 2D node should be 2.";
const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).Indexof('C');
int32_t c_ind = Layout(data_layout).Indexof('c');
CHECK(C_ind != -1 || c_ind != -1)
<< "There is no input channel dimension.";
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C_ind must exist actually.

int64_t input_channel = 1;
if (C_ind != -1)
input_channel *= static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
if (c_ind != -1)
input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
CHECK(kernel_size.size() == 2)
<< "The dimension of the kernel size in Conv 2D should be 2.";
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
Array<IndexExpr> output_tensor = expr->shape;
CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
<< "The dimension of the output tensor in Conv 2D should be 4 or 5.";
int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
return count;
}

/*
* \brief Get the number of MACs of a Dense node.
* \param call_node The Dense call node.
* \return The number of MACs.
*/
int64_t ComputeDenseMacs(const CallNode* call_node) {
CHECK(IsDenseNode(call_node))
<< "The input call node must be a Dense node.";
if (!call_node->checked_type_.defined()) {
LOG(WARNING) << "The infer type pass should be called before the mac count pass";
return 0;
}
Array<Expr> args = call_node->args;
CHECK(args.size() == 2)
<< "The number of input arguments of a Dense node should be 2.";
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
Array<IndexExpr> weight_shape = weight_type->shape;
CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
<< "The dimension of an input tensor to Dense node should be 2.";
int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImm>()->value);
int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
CHECK(d2 == d4)
<< "The dimensions of input arguments do not match.";
int64_t count = d1 * d2 * d3;
return count;
}

int64_t GetCartesianProd(Array<IndexExpr> arr) {
int64_t ret = 1;
for (size_t i = 0; i < arr.size(); i++) {
const auto* intImm = arr[i].as<IntImm>();
ret *= static_cast<int64_t>(intImm->value);
}
return ret;
}

int64_t count_;
};

// Public entry point: delegate the counting to a MacCounter visitor.
int64_t GetTotalMacNumber(const Expr& expr) {
  const int64_t total = MacCounter::GetTotalMacNumber(expr);
  return total;
}

// Expose the pass to the Python frontend as relay._ir_pass.GetTotalMacNumber.
TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = GetTotalMacNumber(args[0]);
  });

} // namespace relay
} // namespace tvm
90 changes: 90 additions & 0 deletions tests/python/relay/test_pass_mac_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Unit tests for MAC counter."""
import tvm
from tvm import relay
import sys

def test_gemm():
    # dense(data1: (n, k), data2: (m, k)) performs n * m * k MACs.
    n, k, m = 512, 1024, 256
    lhs = relay.var("data1", shape=(n, k))
    rhs = relay.var("data2", shape=(m, k))
    gemm = relay.nn.dense(lhs, rhs)
    func = relay.Function([lhs, rhs], relay.Tuple(tvm.convert([gemm])))
    func = relay.ir_pass.infer_type(func)
    compute_count = relay.ir_pass.get_total_mac_number(func)
    assert compute_count == n * m * k

def test_conv():
    # conv2d MACs = batch * in_ch * out_h * out_w * out_ch * kh * kw.
    batch_size, in_channel, out_channel = 1, 3, 64
    height = width = 224
    kh = kw = 7
    pad = 1
    out_h = height + 2 * pad - kh + 1
    out_w = width + 2 * pad - kw + 1
    weight = relay.var("weight", shape=(out_channel, in_channel, kh, kw))
    data = relay.var("data", shape=(batch_size, in_channel, height, width))
    conv2d = relay.nn.conv2d(
        data,
        weight,
        channels=out_channel,
        kernel_size=(kh, kw),
        padding=(1, 1))
    func = relay.Function([data, weight], relay.Tuple(tvm.convert([conv2d])))
    func = relay.ir_pass.infer_type(func)
    compute_count = relay.ir_pass.get_total_mac_number(func)
    expect_count = batch_size * in_channel * out_h * out_w * out_channel * kh * kw
    assert compute_count == expect_count

def test_simple_network():
    # Two 3x3 convs on (1, 64, 56, 56) plus a dense reducing 56*56*64 -> 1:
    #   per conv: 64 * 3 * 3 * (64 * 56 * 56) = 115,605,504 MACs
    #   dense:    1 * (56 * 56 * 64) * 1      =     200,704 MACs
    #   total:    2 * 115605504 + 200704      = 231,411,712 MACs
    batch_size = 1
    dshape = (batch_size, 64, 56, 56)
    weight_conv = relay.var("weight_conv", shape=(64, 64, 3, 3))
    data1 = relay.var("data1", shape=dshape)
    data2 = relay.var("data2", shape=dshape)
    weight_dense = relay.var("weight_dense", shape=(1, 56*56*64))

    conv_a = relay.nn.conv2d(
        data1, weight_conv, channels=64, kernel_size=(3, 3), padding=(1, 1))
    conv_b = relay.nn.conv2d(
        data2, weight_conv, channels=64, kernel_size=(3, 3), padding=(1, 1))
    summed = relay.add(conv_a, conv_b)
    flattened = relay.nn.batch_flatten(summed)
    dense_out = relay.nn.dense(flattened, weight_dense)

    func = relay.Function(
        [data1, data2, weight_conv, weight_dense],
        relay.Tuple(tvm.convert([conv_a, conv_b, dense_out, summed, flattened])))
    func = relay.ir_pass.infer_type(func)
    # Alter the CONV 2D data layout (e.g. to NCHWc) to exercise the
    # split-channel ('C'/'c') path of the counter.
    func = relay.ir_pass.alter_op_layout(func)
    func = relay.ir_pass.infer_type(func)
    compute_count = relay.ir_pass.get_total_mac_number(func)
    expect_count = 231411712
    assert compute_count == expect_count

if __name__ == "__main__":
    # Run each case in the original order.
    for case in (test_conv, test_gemm, test_simple_network):
        case()