Skip to content

Commit 4f664f5

Browse files
zhresholdtqchen
authored and committed
[Frontend] Onnx (apache#40)
* init onnx finish onnx frontend add onnx tests fix various backup use transformer [Frontend] graph passed add test forward test forward fix doc and lint fix test graph tuple from_onnx now take 2 args, output (sym, params) fix rename fix input names fix multiple fix lint fix lint check * better doc
1 parent dddd8d1 commit 4f664f5

20 files changed

Lines changed: 1105 additions & 223 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
"""NNVM frontends."""
22
from __future__ import absolute_import
33
from .mxnet import from_mxnet
4+
from .onnx import from_onnx
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""Shared functions and classes for frontends."""
2+
from __future__ import absolute_import as _abs
3+
import warnings
4+
from .._base import string_types
5+
6+
class Renamer(object):
    """A simple renamer for operators.

    Maps an operator to a fixed new name while passing its attributes
    through untouched.

    Parameters
    ----------
    new_name : str
        The new name for the operator
    """
    def __init__(self, new_name):
        self._new_name = new_name

    def __call__(self, attrs):
        """Return ``(new_name, attrs)`` — attributes are passed through unchanged."""
        return self._new_name, attrs
19+
20+
21+
class AttrConverter(object):
    """Common attribute converter. An AttrConverter instance is a callable:
    ```
    attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
    new_op_name, new_attr = attr_converter(attrs)
    ```

    Parameters
    ----------
    op_name : str or callable
        If set as str, returned operator name is the str.
        If set as callable, returned operator is the str returned by calling:
        `op_name = func(attr)`
    transforms : dict of `new_name, or (new_name, default_value, transform function)`
        If only a new_name is provided, it's like renaming the attribute name.
        If default_value is provided, then the attribute is considered as optional.
        If transform function is provided, the original attribute value is handled
        by transform function.
    excludes : list
        A list of excluded attributes that should `NOT` appear.
        Raise NotImplementedError if occurred.
    disables : list
        A list of attributes that is disabled in nnvm. Raise warnings.
    ignores : list
        A list of attributes that is ignored in nnvm. Silent.
    extras : dict
        A series of additional attributes should be added anyway to the returned
        attribute dict.
    custom_check : tuple of (callable, str), optional
        A ``(function, message)`` pair. The function takes the attribute dict
        and returns True/False; a ``RuntimeError`` carrying the message is
        raised when it returns a falsy value.
    """
    def __init__(self, op_name, transforms=None,
                 excludes=None, disables=None, ignores=None,
                 extras=None, custom_check=None):
        self._op_name = op_name
        # Replace None with fresh containers so instances never share
        # mutable state (avoids the mutable-default-argument pitfall).
        self._transforms = transforms if transforms else {}
        self._excludes = excludes if excludes else []
        self._disables = disables if disables else []
        self._ignores = ignores if ignores else []
        self._extras = extras if extras else {}
        self._custom_check = custom_check

    def __call__(self, attrs):
        """Convert ``attrs`` and return ``(new_op_name, new_attrs)``.

        Raises
        ------
        RuntimeError
            If the ``custom_check`` predicate rejects ``attrs``.
        NotImplementedError
            If an excluded attribute is present.
        AttributeError
            If a required (no-default) transformed attribute is missing.
        """
        # apply custom check
        if self._custom_check:
            func, msg = self._custom_check
            if not func(attrs):
                raise RuntimeError("Check failed: {}".format(msg))
        # get new op_name
        if isinstance(self._op_name, string_types):
            op_name = self._op_name
        else:
            assert callable(self._op_name), "op_name can either be string or callable"
            op_name = self._op_name(attrs)
        # convert attributes
        new_attrs = {}
        for k in attrs.keys():
            if k in self._excludes:
                raise NotImplementedError("Attribute {} not supported yet.".format(k))
            elif k in self._disables:
                warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name))
            elif k in self._ignores:
                pass
            elif k in self._transforms:
                new_name, defaults, transform = self._parse_default(self._transforms[k])
                if defaults is None:
                    # No default provided: the attribute is required.
                    new_attr = self._required_attr(attrs, k)
                else:
                    new_attr = attrs.get(k, None)
                if new_attr is None:
                    new_attrs[new_name] = defaults
                else:
                    new_attrs[new_name] = transform(new_attr)
            else:
                # copy through unchanged
                new_attrs[k] = attrs[k]
        # add extras
        new_attrs.update(self._extras)
        return op_name, new_attrs

    def _parse_default(self, target):
        """Helper function to parse default values.

        Normalizes a transform spec into a ``(new_name, default, transform)``
        triple; a missing default/transform falls back to ``None`` / identity.
        """
        if not isinstance(target, (list, tuple)):
            k, v, t = target, None, lambda x: x
        elif len(target) == 1:
            k, v, t = target[0], None, lambda x: x
        elif len(target) == 2:
            k, v, t = target[0], target[1], lambda x: x
        elif len(target) > 2:
            k, v, t = target[0], target[1], target[2]
        else:
            # Empty sequence: bind all three so the name check below raises a
            # clear ValueError (the original left v/t unbound on this path).
            k, v, t = None, None, None
        if not isinstance(k, string_types):
            msg = "{} is not a valid target, (name, default) expected.".format(target)
            raise ValueError(msg)
        return k, v, t

    def _parse_bool(self, value):
        """Helper function to parse default boolean values."""
        if isinstance(value, string_types):
            return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
        return bool(value)

    def _required_attr(self, attr, key):
        """Wrapper for getting required attributes.

        Raises ``AttributeError`` when ``key`` is absent from ``attr``.
        """
        assert isinstance(attr, dict)
        if key not in attr:
            raise AttributeError("Required attribute {} not found.".format(key))
        return attr[key]

nnvm/python/nnvm/frontend/mxnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def _pooling(attrs):
5858
def _batch_norm(attrs):
5959
if _parse_bool_str(attrs, 'output_mean_var'):
6060
_raise_not_supported('output_mean_var', 'batch_norm')
61-
if _parse_bool_str(attrs, 'fix_gamma'):
62-
_warn_not_used('fix_gamma', 'batch_norm')
61+
# if _parse_bool_str(attrs, 'fix_gamma'):
62+
# _warn_not_used('fix_gamma', 'batch_norm')
6363
if _parse_bool_str(attrs, 'use_global_stats'):
6464
_warn_not_used('use_global_stats', 'batch_norm')
65-
if _parse_bool_str(attrs, 'momentum'):
66-
_warn_not_used('momentum', 'batch_norm')
65+
# if _parse_bool_str(attrs, 'momentum'):
66+
# _warn_not_used('momentum', 'batch_norm')
6767
op_name, new_attrs = 'batch_norm', {}
6868
new_attrs['axis'] = attrs.get('axis', 1)
6969
new_attrs['epsilon'] = attrs.get('eps', 0.001)

0 commit comments

Comments
 (0)