
Commit aeae308

kevinthesun authored and anijain2305 committed
Improve graph tuner dealing with Tuple (apache#3649)
* Improve graph tuner dealing with Tuple
* Add test case
* Move some data out of _base.py
* Fix lint
1 parent 656ea51 commit aeae308

5 files changed

Lines changed: 139 additions & 25 deletions


python/tvm/autotvm/graph_tuner/_base.py

Lines changed: 0 additions & 5 deletions

@@ -18,11 +18,6 @@
 """Helper functions and global data"""


-# Operators dependent on original layouts.
-LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
-                   "multibox_prior", "multibox_transform_loc", "where",
-                   "non_max_suppression", "strided_slice"]
-
 # We set a large time to represent an invalid layout-transformation.
 # This number is set to be 10e9 seconds to align with autotvm.
 INVALID_LAYOUT_TIME = 10e9

python/tvm/autotvm/graph_tuner/base_graph_tuner.py

Lines changed: 13 additions & 1 deletion

@@ -444,6 +444,7 @@ def _callback(_, inputs, results):
                                        timeout=timeout)
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
+            data, in_layout, out_layout = args
             args = serialize_args(args)
             ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
             if ltf_workload in self._layout_transform_perf_records:
@@ -454,7 +455,18 @@ def _callback(_, inputs, results):
             flops = 1
             for i in input_shape:
                 flops *= i
-            inferred_time = flops * avg_time
+
+            # Rule out invalid layout transformations
+            out = topi.layout_transform(data, in_layout, out_layout)
+            out_flops = 1
+            for i in topi.util.get_const_tuple(out.shape):
+                out_flops *= i
+
+            if flops != out_flops:
+                inferred_time = INVALID_LAYOUT_TIME
+            else:
+                inferred_time = flops * avg_time
+
             record_input = MeasureInput(target=self._target, task=None, config=None)
             record_output = MeasureResult(costs=(inferred_time,), error_no=0,
                                           all_cost=-1, timestamp=-1)
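The new branch infers a cost only for transforms that preserve the tensor's total element count; anything else records INVALID_LAYOUT_TIME so the tuner never selects it. A minimal standalone sketch of that rule in plain Python (no TVM dependency; num_elements and infer_time are hypothetical names used only for illustration):

    from functools import reduce
    from operator import mul

    INVALID_LAYOUT_TIME = 10e9  # same sentinel _base.py defines

    def num_elements(shape):
        # Total number of elements for a shape tuple.
        return reduce(mul, shape, 1)

    def infer_time(in_shape, out_shape, avg_time):
        # Mirror of the hunk above: a transform that changes the element
        # count (e.g. a channel split that forces padding) is invalid.
        flops = num_elements(in_shape)
        if flops != num_elements(out_shape):
            return INVALID_LAYOUT_TIME
        return flops * avg_time

    # NCHW (1, 5, 32, 32) -> NCHW4c pads C=5 up to 8, giving (1, 2, 32, 32, 4):
    print(infer_time((1, 5, 32, 32), (1, 2, 32, 32, 4), 1e-9))  # 10000000000.0
    # NCHW (1, 8, 32, 32) -> NCHW4c splits exactly into (1, 2, 32, 32, 4):
    print(infer_time((1, 8, 32, 32), (1, 2, 32, 32, 4), 1e-9))  # 8.192e-06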

python/tvm/autotvm/graph_tuner/utils/traverse_graph.py

Lines changed: 5 additions & 3 deletions

@@ -26,7 +26,7 @@
 from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv

-from .utils import has_multiple_inputs, is_boundary_node
+from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node


 # Setup relay op base name -> topi compute functions
@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
     visited_dict = {}
     in_node_dict = {}
     for i, node in enumerate(node_list):
-        if is_boundary_node(node, input_names):
+        if is_boundary_node(node, input_names) or is_skipped_node(node):
             continue
         get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
     for key, val in visited_dict.items():
@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names):
             boundary_nodes.append(key)
     if boundary_nodes:
         for idx in boundary_nodes:
-            del in_node_dict[idx]
+            if idx in in_node_dict:
+                del in_node_dict[idx]
     else:
         has_reduced_node = False

+
     # Remove empty nodes to ignore pre-computed sub-graph
     has_empty_node = True
     while has_empty_node:
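Together with the is_skipped_node predicate added to utils.py below, the new guard in get_in_nodes ignores Tuple nodes outright instead of traversing them as compute nodes. A self-contained sketch of the combined check, using hypothetical node_entry dicts in place of the tuner's real node list:

    def is_boundary_node(node_entry, input_names):
        # Operators dependent on original layouts, plus graph inputs.
        _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
                            "multibox_prior", "multibox_transform_loc", "where",
                            "non_max_suppression", "strided_slice"]
        return node_entry["op"] in _LAYOUT_FIXED_OP or \
               ("name" in node_entry and node_entry["name"] in input_names)

    def is_skipped_node(node_entry):
        # Operators not counted in graph tuner.
        return node_entry["op"] in ["Tuple"]

    node_list = [{"op": "null", "name": "data"}, {"op": "conv2d"},
                 {"op": "Tuple"}, {"op": "concatenate"}]
    for node in node_list:
        if is_boundary_node(node, ["data"]) or is_skipped_node(node):
            print(node["op"], "-> skipped")
        else:
            print(node["op"], "-> traversed")
    # null -> skipped, conv2d -> traversed, Tuple -> skipped,
    # concatenate -> traversed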

python/tvm/autotvm/graph_tuner/utils/utils.py

Lines changed: 25 additions & 3 deletions

@@ -19,8 +19,6 @@
 from tvm import relay
 from tvm.relay import transform

-from .._base import LAYOUT_FIXED_OP
-

 def has_multiple_inputs(node_list, node_idx, input_names):
     """Check whether a node has multiple input nodes
@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names):
     out : bool
         whether node is a boundary node.
     """
-    out = node_entry["op"] in LAYOUT_FIXED_OP or \
+    # Operators dependent on original layouts.
+    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
+                        "multibox_prior", "multibox_transform_loc", "where",
+                        "non_max_suppression", "strided_slice"]
+
+    out = node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)
     return out


+def is_skipped_node(node_entry):
+    """Whether a node is not counted.
+
+    Parameters
+    ----------
+    node_entry : dict
+        Node entry.
+
+    Returns
+    -------
+    out : bool
+        whether node is skipped.
+    """
+    # Operators not counted in graph tuner.
+    _SKIPPED_OP = ["Tuple"]
+
+    return node_entry["op"] in _SKIPPED_OP
+
+
 def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
     """Bind input variables of a relay function expression
     to new shapes and/or dtypes.
tests/python/unittest/test_graph_tuner_core.py

Lines changed: 96 additions & 13 deletions

@@ -354,33 +354,115 @@ def test_many_sub_graphs():
     ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))

-    ltf_keys = []
-    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
-    ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
-    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
+    executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+
+    executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+
+
+def test_tuple():
+    target = "llvm"
+    dtype = "float32"
+    dshape = (1, 5, 32, 32)
+    layout = "NCHW"
+    target_ops = [relay.nn.conv2d]
+
+    data = relay.var("data", shape=dshape, dtype=dtype)
+    w0 = relay.var("w0_weight")
+    conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
+    w1 = relay.var("w1_weight")
+    conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
+    out = relay.concatenate([conv0, conv1], axis=1)
+    net = relay.Function(relay.analysis.free_vars(out), out)
+    net, params = relay.testing.create_workload(net)
+
+    tasks = autotvm.task.extract_from_program(net["main"],
+                                              target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
+    wkl_list = [
+        create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+        create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
+    costs = [0.01, 0.012, 0.03, 0.04]
+    config_list = []
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [1, 2]],
+                      ["tile_ow", "sp", [4, 8]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [1, 3]],
+                      ["tile_ow", "sp", [2, 16]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [2, 1]],
+                      ["tile_ow", "sp", [4, 8]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [3, 1]],
+                      ["tile_ow", "sp", [2, 16]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+
+    records = []
+
+    wkl_list = wkl_list + wkl_list
+    tasks = tasks + tasks
+    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
+        task.workload = wkl
+        ms_input = MeasureInput(target=target, task=task, config=config)
+        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+        records.append((ms_input, ms_output))
+
+    ltf_records = []
+    ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
     ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
     ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
+    ltf_task = copy.deepcopy(tasks[0])
+    ltf_task.workload = ltf_wkl
+    ms_input = MeasureInput(target=target, task=ltf_task, config=None)
+    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ltf_records.append((ms_input, ms_output))

     executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
-    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    expected_out = [records[2][0].config, records[1][0].config]
     assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                 % (str(expected_out), str(out))

     executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
-    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    expected_out = [records[2][0].config, records[1][0].config]
     assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                 % (str(expected_out), str(out))

@@ -390,3 +472,4 @@ def test_many_sub_graphs():
     test_DPTuner_run()
     test_PBQPTuner_run()
     test_many_sub_graphs()
+    test_tuple()
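The new test_tuple case exercises exactly the pattern this commit fixes: relay.concatenate takes its operands as a single Tuple, so the node list the graph tuner walks contains a Tuple node between the two conv2d outputs and the concatenate. A short sketch (assuming a TVM build of this vintage with Relay available) that builds the same two-branch graph as the test and prints its IR, where the Tuple is visible:

    from tvm import relay

    data = relay.var("data", shape=(1, 5, 32, 32), dtype="float32")
    w0 = relay.var("w0_weight")
    conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
    w1 = relay.var("w1_weight")
    conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
    out = relay.concatenate([conv0, conv1], axis=1)
    # The printed function shows a tuple of the two conv results feeding
    # concatenate; that tuple is the node is_skipped_node now filters out.
    print(relay.Function(relay.analysis.free_vars(out), out))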
