Skip to content

Commit 20ac351

Browse files
committed
[CYTHON] Make speedup component minimum (#13)
1 parent ad0ab0a commit 20ac351

File tree

4 files changed

+241
-359
lines changed

4 files changed

+241
-359
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Module space to register internal functions. Leave empty"""

nnvm/python/nnvm/ctypes/symbol.py

Lines changed: 18 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
from ..name import NameManager
1414
from ..attribute import AttrScope
1515

16-
__all__ = ["Symbol", "Variable"]
17-
18-
class Symbol(object):
16+
class SymbolBase(object):
1917
"""Symbol is symbolic graph."""
20-
18+
__slots__ = ["handle"]
2119
# pylint: disable=no-member
2220
def __init__(self, handle):
2321
"""Initialize the function with handle
@@ -32,15 +30,6 @@ def __init__(self, handle):
3230
def __del__(self):
3331
check_call(_LIB.NNSymbolFree(self.handle))
3432

35-
def __copy__(self):
36-
return copy.deepcopy(self)
37-
38-
def __deepcopy__(self, _):
39-
handle = SymbolHandle()
40-
check_call(_LIB.NNSymbolCopy(self.handle,
41-
ctypes.byref(handle)))
42-
return Symbol(handle)
43-
4433
def __call__(self, *args, **kwargs):
4534
"""Invoke symbol as function on inputs.
4635
@@ -85,10 +74,10 @@ def _compose(self, *args, **kwargs):
8574
either as positional or keyword arguments, not both')
8675

8776
for arg in args:
88-
if not isinstance(arg, Symbol):
77+
if not isinstance(arg, SymbolBase):
8978
raise TypeError('Compose expect `Symbol` as arguments')
9079
for val in kwargs.values():
91-
if not isinstance(val, Symbol):
80+
if not isinstance(val, SymbolBase):
9281
raise TypeError('Compose expect `Symbol` as arguments')
9382

9483
num_args = len(args) + len(kwargs)
@@ -101,65 +90,6 @@ def _compose(self, *args, **kwargs):
10190
check_call(_LIB.NNSymbolCompose(
10291
self.handle, name, num_args, keys, args))
10392

104-
def __getitem__(self, index):
105-
if isinstance(index, string_types):
106-
idx = None
107-
for i, name in enumerate(self.list_outputs()):
108-
if name == index:
109-
if idx is not None:
110-
raise ValueError('There are multiple outputs with name \"%s\"' % index)
111-
idx = i
112-
if idx is None:
113-
raise ValueError('Cannot find output that matches name \"%s\"' % index)
114-
index = idx
115-
if not isinstance(index, int):
116-
raise TypeError('Symbol only support integer index to fetch i-th output')
117-
handle = SymbolHandle()
118-
check_call(_LIB.NNSymbolGetOutput(
119-
self.handle, nn_uint(index), ctypes.byref(handle)))
120-
return Symbol(handle=handle)
121-
122-
def attr(self, key):
123-
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
124-
125-
Parameters
126-
----------
127-
key : str
128-
The key to get attribute from.
129-
130-
Returns
131-
-------
132-
value : str
133-
The attribute value of the key, returns None if attribute do not exist.
134-
"""
135-
ret = ctypes.c_char_p()
136-
success = ctypes.c_int()
137-
check_call(_LIB.NNSymbolGetAttr(
138-
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
139-
if success.value != 0:
140-
return py_str(ret.value)
141-
else:
142-
return None
143-
144-
def list_attr(self, recursive=False):
145-
"""Get all attributes from the symbol.
146-
147-
Parameters
148-
----------
149-
recursive : bool
150-
Default `False`. When `recursive` is `True`, list recursively all the
151-
attributes in the descendents. The attribute names are pre-pended with
152-
the symbol names to avoid conflicts. If `False`, then only attributes
153-
that belongs to this symbol is returned, and the attribute names will
154-
**not** be pre-pended with the symbol name.
155-
"""
156-
size = nn_uint()
157-
pairs = ctypes.POINTER(ctypes.c_char_p)()
158-
option = ctypes.c_int(0) if recursive else ctypes.c_int(1)
159-
check_call(_LIB.NNSymbolListAttrs(
160-
self.handle, option, ctypes.byref(size), ctypes.byref(pairs)))
161-
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)}
162-
16393
def _set_attr(self, **kwargs):
16494
"""Set the attribute of the symbol.
16595
@@ -168,116 +98,20 @@ def _set_attr(self, **kwargs):
16898
**kwargs
16999
The attributes to set
170100
"""
171-
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
172-
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
173-
num_args = nn_uint(len(kwargs))
174-
check_call(_LIB.NNSymbolSetAttrs(
101+
keys = _base.c_array(_ctypes.c_char_p,
102+
[_base.c_str(key) for key in kwargs.keys()])
103+
vals = _base.c_array(_ctypes.c_char_p,
104+
[_base.c_str(str(val)) for val in kwargs.values()])
105+
num_args = _base.nn_uint(len(kwargs))
106+
_check_call(_LIB.NNSymbolSetAttrs(
175107
self.handle, num_args, keys, vals))
176108

177-
def get_internals(self):
178-
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
179109

180-
Returns
181-
-------
182-
sgroup : Symbol
183-
The internal of the symbol.
184-
"""
185-
handle = SymbolHandle()
186-
check_call(_LIB.NNSymbolGetInternals(
187-
self.handle, ctypes.byref(handle)))
188-
return Symbol(handle=handle)
189-
190-
def list_arguments(self):
191-
"""List all the arguments in the symbol.
192-
193-
Returns
194-
-------
195-
args : list of string
196-
List of all the arguments.
197-
"""
198-
size = ctypes.c_uint()
199-
sarr = ctypes.POINTER(ctypes.c_char_p)()
200-
check_call(_LIB.NNSymbolListArguments(
201-
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
202-
return [py_str(sarr[i]) for i in range(size.value)]
203-
204-
def list_outputs(self):
205-
"""List all outputs in the symbol.
110+
_symbol_cls = SymbolBase
206111

207-
Returns
208-
-------
209-
returns : list of string
210-
List of all the outputs.
211-
"""
212-
size = ctypes.c_uint()
213-
sarr = ctypes.POINTER(ctypes.c_char_p)()
214-
check_call(_LIB.NNSymbolListOutputs(
215-
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
216-
return [py_str(sarr[i]) for i in range(size.value)]
217-
218-
def debug_str(self):
219-
"""Get a debug string.
220-
221-
Returns
222-
-------
223-
debug_str : string
224-
Debug string of the symbol.
225-
"""
226-
debug_str = ctypes.c_char_p()
227-
check_call(_LIB.NNSymbolPrint(
228-
self.handle, ctypes.byref(debug_str)))
229-
return py_str(debug_str.value)
230-
231-
232-
def Variable(name, **kwargs):
233-
"""Create a symbolic variable with specified name.
234-
235-
Parameters
236-
----------
237-
name : str
238-
Name of the variable.
239-
kwargs : dict of string -> string
240-
Additional attributes to set on the variable.
241-
242-
Returns
243-
-------
244-
variable : Symbol
245-
The created variable symbol.
246-
"""
247-
if not isinstance(name, string_types):
248-
raise TypeError('Expect a string for variable `name`')
249-
handle = SymbolHandle()
250-
check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
251-
ret = Symbol(handle)
252-
attr = AttrScope.current.get(kwargs)
253-
if attr:
254-
ret._set_attr(**attr)
255-
return ret
256-
257-
258-
def Group(symbols):
259-
"""Create a symbol that groups symbols together.
260-
261-
Parameters
262-
----------
263-
symbols : list
264-
List of symbols to be grouped.
265-
266-
Returns
267-
-------
268-
sym : Symbol
269-
The created group symbol.
270-
"""
271-
ihandles = []
272-
for sym in symbols:
273-
if not isinstance(sym, Symbol):
274-
raise TypeError('Expect Symbols in the list input')
275-
ihandles.append(sym.handle)
276-
handle = SymbolHandle()
277-
check_call(_LIB.NNSymbolCreateGroup(
278-
nn_uint(len(ihandles)),
279-
c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
280-
return Symbol(handle)
112+
def _set_symbol_class(cls):
113+
global _symbol_cls
114+
_symbol_cls = cls
281115

282116

283117
def _make_atomic_symbol_function(handle):
@@ -332,7 +166,7 @@ def creator(*args, **kwargs):
332166
attr = kwargs.pop('attr', None)
333167

334168
for k, v in kwargs.items():
335-
if isinstance(v, Symbol):
169+
if isinstance(v, SymbolBase):
336170
symbol_kwargs[k] = v
337171
else:
338172
param_keys.append(c_str(k))
@@ -351,7 +185,7 @@ def creator(*args, **kwargs):
351185
raise TypeError(
352186
'%s can only accept input'
353187
'Symbols either as positional or keyword arguments, not both' % func_name)
354-
s = Symbol(sym_handle)
188+
s = _symbol_cls(sym_handle)
355189
attr = AttrScope.current.get(attr)
356190
if attr:
357191
s._set_attr(**attr)
@@ -373,11 +207,12 @@ def _init_symbol_module():
373207
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
374208
ctypes.byref(plist)))
375209
module_obj = sys.modules["nnvm.symbol"]
210+
module_internal = sys.modules["nnvm._symbol_internal"]
376211
for i in range(size.value):
377212
hdl = SymbolHandle(plist[i])
378213
function = _make_atomic_symbol_function(hdl)
379214
if function.__name__.startswith('_'):
380-
setattr(Symbol, function.__name__, staticmethod(function))
215+
setattr(module_internal, function.__name__, function)
381216
else:
382217
setattr(module_obj, function.__name__, function)
383218

0 commit comments

Comments
 (0)