
Commit de18299

Peter Yeh committed
Merge branch 'master' into benchmark
* master: (21 commits)
  - [Fix][VM] Fix VM invoke with set_params (apache#4079)
  - [QNN] Refactor fixed point multiplication in requantize (apache#4073)
  - Fix match case in Python-side expr functor (apache#4037)
  - Hide symbols from dependent libraries if HIDE_PRIVATE_SYMBOLS is ON. (apache#4041)
  - Add gradient for log-softmax (apache#4069)
  - [DOC] Fix typos in tutorials (apache#4066)
  - dicrease the complexity of CalcDep from exponential to linear (apache#4053)
  - [Relay][AlterOp] Minor refactor. (apache#4064)
  - [Relay][AlterOp] Improving support for broadcast layout alteration. (apache#4040)
  - Add parses support for zeros_like tflite operator (apache#4042)
  - [Bugfix][TF] reset graph after getting tag of savedmodel (apache#4055)
  - [Relay][VM] Add more passes to VMCompiler (apache#4058)
  - [Relay][VM] Add autotvm context when compile (apache#4062)
  - [Bugfix] Fix target host for vm compiler (apache#4057)
  - [Relay][Training] Add gradient for Crossentropy (apache#3925)
  - [llvm] switch to use Align for llvm trunk (apache#4051)
  - [Relay][TopHub] Add switch to disable TopHub download (apache#4015)
  - [Relay][Op] Add instance norm op (apache#4004)
  - [QNN][Relay] Calling Dialect passes from inside Relay Build API. (apache#3971)
  - [RELAY/PASS] Fix the extent for the post_stmt in the loop partition (apache#3734)
  - ...
2 parents 703c556 + b5bcdbb commit de18299

61 files changed: 1,374 additions & 301 deletions


CMakeLists.txt

Lines changed: 11 additions & 0 deletions
@@ -294,6 +294,17 @@ target_link_libraries(tvm_topi tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
 target_link_libraries(tvm_runtime ${TVM_RUNTIME_LINKER_LIBS})
 target_link_libraries(nnvm_compiler tvm)
 
+if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+  set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL")
+  # Note: 'target_link_options' with 'PRIVATE' keyword would be cleaner
+  # but it's not available until CMake 3.13. Switch to 'target_link_options'
+  # once minimum CMake version is bumped up to 3.13 or above.
+  target_link_libraries(tvm ${HIDE_SYMBOLS_LINKER_FLAGS})
+  target_link_libraries(tvm_topi ${HIDE_SYMBOLS_LINKER_FLAGS})
+  target_link_libraries(tvm_runtime ${HIDE_SYMBOLS_LINKER_FLAGS})
+  target_link_libraries(nnvm_compiler ${HIDE_SYMBOLS_LINKER_FLAGS})
+endif()
+
 # Related headers
 target_include_directories(
   tvm
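The block is guarded to non-Darwin platforms because the macOS linker does not support `--exclude-libs`. To opt in, set `HIDE_PRIVATE_SYMBOLS` to `ON` in `config.cmake` (or pass `-DHIDE_PRIVATE_SYMBOLS=ON` at configure time); the `-Wl,--exclude-libs,ALL` flag then keeps symbols from statically linked dependencies (LLVM being the usual motivation) out of the exported symbol tables of the TVM shared libraries.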

include/tvm/data_layout.h

Lines changed: 22 additions & 0 deletions
@@ -210,6 +210,28 @@ class Layout : public NodeRef {
     return ct;
   }
 
+  /*!
+   * \brief Returns a new layout where the dims have been expanded to match the primal dimensions.
+   * \param dst_layout The dst layout to which the current layout has to be expanded.
+   * \return The expanded Layout.
+   */
+  inline Layout ExpandPrimal(const Layout& dst_layout) {
+    Layout new_src_layout;
+    // 1) Find the axes that are missing in the current layout. Make them the prefix.
+    std::string new_src_layout_str = "";
+    for (auto dst_axis : dst_layout->axes) {
+      if (LayoutAxis::Get(dst_axis).IsPrimal()) {
+        if (!this->Contains(LayoutAxis::Get(dst_axis))) {
+          new_src_layout_str += dst_axis->var->name_hint;
+        }
+      }
+    }
+    // 2) Now, append the axes of the current layout.
+    new_src_layout_str += this->name();
+    new_src_layout = Layout(new_src_layout_str);
+    return new_src_layout;
+  }
+
   /*!
    * \brief return the index of the input axis.
    *        If it is not found in the layout or the layout is undefined,
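A worked illustration (not part of the diff itself): expanding a source layout `C16c` to the destination `NCHW` first collects the destination's primal axes missing from the source (`N`, `H`, `W`) as a prefix, then appends the source layout, yielding `NHWC16c`. The source's own axes, including subordinate ones like `16c`, keep their original order at the end.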

include/tvm/relay/attrs/nn.h

Lines changed: 23 additions & 0 deletions
@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
 };  // struct BatchNormAttrs
 
 
+/*! \brief Attributes used in instance_norm operator */
+struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .describe("Specify which shape axis denotes the channel.")
+        .set_default(1);
+    TVM_ATTR_FIELD(epsilon)
+        .describe("Small float added to variance to avoid dividing by zero")
+        .set_default(1e-5);
+    TVM_ATTR_FIELD(center).set_default(true)
+        .describe("If true, add offset of beta to normalized tensor; "
+                  "otherwise, beta is ignored.");
+    TVM_ATTR_FIELD(scale).set_default(true)
+        .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+  }
+};  // struct InstanceNormAttrs
+
+
 /*! \brief Attributes used in layer_norm operator */
 struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   int axis;
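A minimal sketch of these attributes from the Python side, assuming the `relay.nn.instance_norm` wrapper that apache#4004 adds alongside this struct:

from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8))   # NCHW, so axis=1 is the channel
gamma = relay.var("gamma", shape=(4,))
beta = relay.var("beta", shape=(4,))
# Keyword arguments mirror the defaults declared in InstanceNormAttrs
out = relay.nn.instance_norm(data, gamma, beta, axis=1, epsilon=1e-5,
                             center=True, scale=True)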

include/tvm/relay/op.h

Lines changed: 16 additions & 0 deletions
@@ -153,6 +153,12 @@ class Op : public relay::Expr {
    */
   template <typename ValueType>
   inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
+  /*!
+   * \brief Checks if an attr is present in the registry.
+   * \param attr_name The name of the attribute.
+   * \return bool True if the attr is present.
+   */
+  inline static bool HasAttr(const std::string& attr_name);
   /*!
    * \brief Get an Op for a given operator name.
    *        Will raise an error if the op has not been registered.
@@ -171,6 +177,12 @@ class Op : public relay::Expr {
    * \return reference to GenericOpMap
    */
   TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+  /*!
+   * \brief Checks if the key is present in the registry
+   * \param key The attribute key
+   * \return bool True if the key is present
+   */
+  TVM_DLL static const bool HasGenericAttr(const std::string& key);
 };
 
 /*! \brief Helper structure to register operators */
@@ -393,6 +405,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
   return OpMap<ValueType>(Op::GetGenericAttr(key));
 }
 
+inline bool Op::HasAttr(const std::string& key) {
+  return Op::HasGenericAttr(key);
+}
+
 inline OpNode* OpRegistry::get() {
   return const_cast<OpNode*>(op_.operator->());
 }
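With this addition, callers can cheaply test whether any operator has registered a given attribute, e.g. `if (Op::HasAttr("FTVMQnnLegalize")) { ... }` (the attribute name here is only illustrative), instead of materializing an `OpMap` via `GetAttr` just to probe for it.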

include/tvm/relay/qnn/transform.h

Lines changed: 60 additions & 0 deletions
New file:

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/qnn/transform.h
 *
 * This file implements a pass manager for QNN ops using the Relay pass manager.
 */
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
#define TVM_RELAY_QNN_TRANSFORM_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

using relay::transform::Pass;

namespace qnn {
namespace transform {

/*!
 * \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First, it
 * converts/lowers an expression containing QNN ops to an expression containing only core Relay
 * ops. Each QNN op is lowered to a sequence of existing Relay ops. This is a target-independent
 * pass. One can register the lowering/transformation function for an op using the
 * FTVMQnnCanonicalize attr_name for the FTVMLegalize op attribute. Second, as opposed to Relay
 * Legalize, this one legalizes only QNN ops. One can register a transformation/legalization
 * function for an op by using the FTVMQnnLegalize attr_name for the FTVMLegalize op attribute.
 * The isolation of QNN and Relay Legalize gives us separation of concerns, leading to better
 * software practice. The legalization can be configured to happen per target.
 *
 * \return The pass.
 */
TVM_DLL Pass Legalize();

}  // namespace transform
}  // namespace qnn
}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_QNN_TRANSFORM_H_
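A sketch of driving this pass from Python, assuming the matching `relay.qnn.transform.Legalize` binding is exposed; `qnn_expr` is a placeholder for any expression built from QNN ops:

from tvm import relay

# qnn_expr: placeholder for a Relay expression containing QNN ops
mod = relay.Module.from_expr(qnn_expr)
# Lower the QNN ops into sequences of core Relay ops
mod = relay.qnn.transform.Legalize()(mod)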

python/tvm/autotvm/tophub.py

Lines changed: 29 additions & 8 deletions
@@ -18,7 +18,8 @@
 TopHub: Tensor Operator Hub
 To get the best performance, we typically need auto-tuning for the specific devices.
 TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
-TVM will download these parameters for you when you call nnvm.compiler.build_module .
+TVM will download these parameters for you when you call
+nnvm.compiler.build_module or relay.build.
 """
 # pylint: disable=invalid-name
 
@@ -30,6 +31,16 @@
 from .. import target as _target
 from ..contrib.download import download
 from .record import load_from_file
+from .util import EmptyContext
+
+# environment variable to read TopHub location
+AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION"
+
+# default location of TopHub
+AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub"
+
+# value of AUTOTVM_TOPHUB_LOC_VAR to specify to not read from TopHub
+AUTOTVM_TOPHUB_NONE_LOC = "NONE"
 
 # root path to store TopHub files
 AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
@@ -61,6 +72,9 @@ def _alias(name):
     }
     return table.get(name, name)
 
+def _get_tophub_location():
+    location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None)
+    return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location
 
 def context(target, extra_files=None):
     """Return the dispatch context with pre-tuned parameters.
@@ -75,6 +89,10 @@ def context(target, extra_files=None):
     extra_files: list of str, optional
         Extra log files to load
     """
+    tophub_location = _get_tophub_location()
+    if tophub_location == AUTOTVM_TOPHUB_NONE_LOC:
+        return EmptyContext()
+
     best_context = ApplyHistoryBest([])
 
     targets = target if isinstance(target, (list, tuple)) else [target]
@@ -94,7 +112,7 @@
         for name in possible_names:
             name = _alias(name)
             if name in all_packages:
-                if not check_backend(name):
+                if not check_backend(tophub_location, name):
                     continue
 
                 filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
@@ -108,7 +126,7 @@
     return best_context
 
 
-def check_backend(backend):
+def check_backend(tophub_location, backend):
     """Check whether have pre-tuned parameters of the certain target.
     If not, will download it.
@@ -135,18 +153,21 @@ def check_backend(backend):
     else:
         import urllib2
         try:
-            download_package(package_name)
+            download_package(tophub_location, package_name)
             return True
         except urllib2.URLError as e:
             logging.warning("Failed to download tophub package for %s: %s", backend, e)
             return False
 
 
-def download_package(package_name):
+def download_package(tophub_location, package_name):
     """Download pre-tuned parameters of operators for a backend
 
     Parameters
     ----------
+    tophub_location: str
+        The location to download TopHub parameters from
+
     package_name: str
         The name of package
     """
@@ -160,9 +181,9 @@ def download_package(package_name):
     if not os.path.isdir(path):
         os.mkdir(path)
 
-    logger.info("Download pre-tuned parameters package %s", package_name)
-    download("https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub/%s"
-             % package_name, os.path.join(rootpath, package_name), True, verbose=0)
+    download_url = "{0}/{1}".format(tophub_location, package_name)
+    logger.info("Download pre-tuned parameters package from %s", download_url)
+    download(download_url, os.path.join(rootpath, package_name), True, verbose=0)
 
 
 # global cache for load_reference_log
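The switch is driven entirely by the environment variable introduced above; for example (the mirror URL is a placeholder):

import os

# Point TopHub at a mirror instead of the default tvm-distro location
os.environ["TOPHUB_LOCATION"] = "https://example.com/my-tophub-mirror"

# Or disable TopHub entirely: context() then returns an EmptyContext
# and no download is attempted
os.environ["TOPHUB_LOCATION"] = "NONE"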

python/tvm/relay/backend/vm.py

Lines changed: 16 additions & 1 deletion
@@ -23,6 +23,7 @@
 import numpy as np
 
 import tvm
+from tvm import autotvm
 from tvm._ffi.runtime_ctypes import TVMByteArray
 from . import _vm
 from . import vmobj as _obj
@@ -178,10 +179,24 @@ def compile(self, mod, target=None, target_host=None):
         """
         target = _update_target(target)
         target_host = None if target_host == "" else target_host
+        if not target_host:
+            for device_type, tgt in target.items():
+                if device_type.value == tvm.nd.cpu(0).device_type:
+                    target_host = tgt
+                    break
         if not target_host:
             target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
         target_host = tvm.target.create(target_host)
-        self._compile(mod, target, target_host)
+
+        # If current dispatch context is fallback context (the default root context),
+        # then load pre-tuned parameters from TopHub
+        if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
+            tophub_context = autotvm.tophub.context(list(target.values()))
+        else:
+            tophub_context = autotvm.util.EmptyContext()
+
+        with tophub_context:
+            self._compile(mod, target, target_host)
         return VirtualMachine(self._get_vm())
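A usage sketch under the new behavior (`mod` is a placeholder for a Relay module): a CPU target now doubles as the target host when none is supplied, and compilation wraps itself in a TopHub context unless a dispatch context is already active.

from tvm.relay.backend.vm import VMCompiler

# mod: placeholder for a relay.Module to compile
compiler = VMCompiler()
# target_host is inferred from the CPU target; pre-tuned TopHub
# parameters are applied if no dispatch context is active
vm = compiler.compile(mod, target="llvm")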

python/tvm/relay/expr_functor.py

Lines changed: 4 additions & 1 deletion
@@ -249,7 +249,10 @@ def visit_constructor(self, con):
         return con
 
     def visit_match(self, m):
-        return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses])
+        return Match(
+            self.visit(m.data),
+            [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
+            complete=m.complete)
 
     def visit_ref_create(self, r):
         return RefCreate(self.visit(r.value))
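Before this fix the mutator rebuilt `Match` nodes with the constructor's default `complete=True`, so a deliberately partial match silently became a complete one after any rewriting pass; forwarding `m.complete` preserves the original flag.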

python/tvm/relay/frontend/mxnet.py

Lines changed: 9 additions & 0 deletions
@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
     return _op.nn.batch_norm(*inputs, **new_attrs)
 
 
+def _mx_instance_norm(inputs, attrs):
+    assert len(inputs) == 3
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis", 1)
+    new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
+    return _op.nn.instance_norm(*inputs, **new_attrs)
+
+
 def _mx_layer_norm(inputs, attrs):
     assert len(inputs) == 3
     if attrs.get_bool("output_mean_var", False):
@@ -1133,6 +1141,7 @@ def _mx_one_hot(inputs, attrs):
     "Dropout"       : _mx_dropout,
     "BatchNorm"     : _mx_batch_norm,
     "BatchNorm_v1"  : _mx_batch_norm,
+    "InstanceNorm"  : _mx_instance_norm,
     "LayerNorm"     : _mx_layer_norm,
     "LRN"           : _mx_lrn,
     "L2Normalization" : _mx_l2_normalize,

python/tvm/relay/frontend/onnx.py

Lines changed: 10 additions & 1 deletion
@@ -176,6 +176,15 @@ def _impl_v1(cls, inputs, attr, params):
         return out[0]
 
 
+class InstanceNorm(OnnxOpConverter):
+    """ Operator converter for InstanceNorm.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(op_name='instance_norm')(inputs, attr, params)
+
+
 class Conv(OnnxOpConverter):
     """ Operator converter for Conv.
     """
@@ -999,7 +1008,7 @@ def _get_convert_map(opset):
         'GlobalAveragePool': Renamer('global_avg_pool2d'),
         'GlobalMaxPool': Renamer('global_max_pool2d'),
         'BatchNormalization': BatchNorm.get_converter(opset),
-        # 'InstanceNormalization'
+        'InstanceNormalization': InstanceNorm.get_converter(opset),
         # 'LpNormalization'
         'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
         'Flatten': Flatten.get_converter(opset),
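With the converter registered, ONNX models containing `InstanceNormalization` import like any other; a brief sketch ("model.onnx" and the input name/shape are placeholders):

import onnx
from tvm import relay

onnx_model = onnx.load("model.onnx")  # placeholder path
mod, params = relay.frontend.from_onnx(
    onnx_model, shape={"input": (1, 3, 224, 224)})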
