1313from ..name import NameManager
1414from ..attribute import AttrScope
1515
16- __all__ = ["Symbol" , "Variable" ]
17-
18- class Symbol (object ):
16+ class SymbolBase (object ):
1917 """Symbol is symbolic graph."""
20-
18+ __slots__ = [ "handle" ]
2119 # pylint: disable=no-member
2220 def __init__ (self , handle ):
2321 """Initialize the function with handle
@@ -32,15 +30,6 @@ def __init__(self, handle):
3230 def __del__ (self ):
3331 check_call (_LIB .NNSymbolFree (self .handle ))
3432
35- def __copy__ (self ):
36- return copy .deepcopy (self )
37-
38- def __deepcopy__ (self , _ ):
39- handle = SymbolHandle ()
40- check_call (_LIB .NNSymbolCopy (self .handle ,
41- ctypes .byref (handle )))
42- return Symbol (handle )
43-
4433 def __call__ (self , * args , ** kwargs ):
4534 """Invoke symbol as function on inputs.
4635
@@ -85,10 +74,10 @@ def _compose(self, *args, **kwargs):
8574 either as positional or keyword arguments, not both' )
8675
8776 for arg in args :
88- if not isinstance (arg , Symbol ):
77+ if not isinstance (arg , SymbolBase ):
8978 raise TypeError ('Compose expect `Symbol` as arguments' )
9079 for val in kwargs .values ():
91- if not isinstance (val , Symbol ):
80+ if not isinstance (val , SymbolBase ):
9281 raise TypeError ('Compose expect `Symbol` as arguments' )
9382
9483 num_args = len (args ) + len (kwargs )
@@ -101,65 +90,6 @@ def _compose(self, *args, **kwargs):
10190 check_call (_LIB .NNSymbolCompose (
10291 self .handle , name , num_args , keys , args ))
10392
104- def __getitem__ (self , index ):
105- if isinstance (index , string_types ):
106- idx = None
107- for i , name in enumerate (self .list_outputs ()):
108- if name == index :
109- if idx is not None :
110- raise ValueError ('There are multiple outputs with name \" %s\" ' % index )
111- idx = i
112- if idx is None :
113- raise ValueError ('Cannot find output that matches name \" %s\" ' % index )
114- index = idx
115- if not isinstance (index , int ):
116- raise TypeError ('Symbol only support integer index to fetch i-th output' )
117- handle = SymbolHandle ()
118- check_call (_LIB .NNSymbolGetOutput (
119- self .handle , nn_uint (index ), ctypes .byref (handle )))
120- return Symbol (handle = handle )
121-
122- def attr (self , key ):
123- """Get attribute string from the symbol, this function only works for non-grouped symbol.
124-
125- Parameters
126- ----------
127- key : str
128- The key to get attribute from.
129-
130- Returns
131- -------
132- value : str
133- The attribute value of the key, returns None if attribute do not exist.
134- """
135- ret = ctypes .c_char_p ()
136- success = ctypes .c_int ()
137- check_call (_LIB .NNSymbolGetAttr (
138- self .handle , c_str (key ), ctypes .byref (ret ), ctypes .byref (success )))
139- if success .value != 0 :
140- return py_str (ret .value )
141- else :
142- return None
143-
144- def list_attr (self , recursive = False ):
145- """Get all attributes from the symbol.
146-
147- Parameters
148- ----------
149- recursive : bool
150- Default `False`. When `recursive` is `True`, list recursively all the
151- attributes in the descendents. The attribute names are pre-pended with
152- the symbol names to avoid conflicts. If `False`, then only attributes
153- that belongs to this symbol is returned, and the attribute names will
154- **not** be pre-pended with the symbol name.
155- """
156- size = nn_uint ()
157- pairs = ctypes .POINTER (ctypes .c_char_p )()
158- option = ctypes .c_int (0 ) if recursive else ctypes .c_int (1 )
159- check_call (_LIB .NNSymbolListAttrs (
160- self .handle , option , ctypes .byref (size ), ctypes .byref (pairs )))
161- return {py_str (pairs [i * 2 ]): py_str (pairs [i * 2 + 1 ]) for i in range (size .value )}
162-
16393 def _set_attr (self , ** kwargs ):
16494 """Set the attribute of the symbol.
16595
@@ -168,116 +98,20 @@ def _set_attr(self, **kwargs):
16898 **kwargs
16999 The attributes to set
170100 """
171- keys = c_array (ctypes .c_char_p , [c_str (key ) for key in kwargs .keys ()])
172- vals = c_array (ctypes .c_char_p , [c_str (str (val )) for val in kwargs .values ()])
173- num_args = nn_uint (len (kwargs ))
174- check_call (_LIB .NNSymbolSetAttrs (
101+ keys = _base .c_array (_ctypes .c_char_p ,
102+ [_base .c_str (key ) for key in kwargs .keys ()])
103+ vals = _base .c_array (_ctypes .c_char_p ,
104+ [_base .c_str (str (val )) for val in kwargs .values ()])
105+ num_args = _base .nn_uint (len (kwargs ))
106+ _check_call (_LIB .NNSymbolSetAttrs (
175107 self .handle , num_args , keys , vals ))
176108
177- def get_internals (self ):
178- """Get a new grouped symbol whose output contains all the internal outputs of this symbol.
179109
180- Returns
181- -------
182- sgroup : Symbol
183- The internal of the symbol.
184- """
185- handle = SymbolHandle ()
186- check_call (_LIB .NNSymbolGetInternals (
187- self .handle , ctypes .byref (handle )))
188- return Symbol (handle = handle )
189-
190- def list_arguments (self ):
191- """List all the arguments in the symbol.
192-
193- Returns
194- -------
195- args : list of string
196- List of all the arguments.
197- """
198- size = ctypes .c_uint ()
199- sarr = ctypes .POINTER (ctypes .c_char_p )()
200- check_call (_LIB .NNSymbolListArguments (
201- self .handle , ctypes .byref (size ), ctypes .byref (sarr )))
202- return [py_str (sarr [i ]) for i in range (size .value )]
203-
204- def list_outputs (self ):
205- """List all outputs in the symbol.
110+ _symbol_cls = SymbolBase
206111
207- Returns
208- -------
209- returns : list of string
210- List of all the outputs.
211- """
212- size = ctypes .c_uint ()
213- sarr = ctypes .POINTER (ctypes .c_char_p )()
214- check_call (_LIB .NNSymbolListOutputs (
215- self .handle , ctypes .byref (size ), ctypes .byref (sarr )))
216- return [py_str (sarr [i ]) for i in range (size .value )]
217-
218- def debug_str (self ):
219- """Get a debug string.
220-
221- Returns
222- -------
223- debug_str : string
224- Debug string of the symbol.
225- """
226- debug_str = ctypes .c_char_p ()
227- check_call (_LIB .NNSymbolPrint (
228- self .handle , ctypes .byref (debug_str )))
229- return py_str (debug_str .value )
230-
231-
232- def Variable (name , ** kwargs ):
233- """Create a symbolic variable with specified name.
234-
235- Parameters
236- ----------
237- name : str
238- Name of the variable.
239- kwargs : dict of string -> string
240- Additional attributes to set on the variable.
241-
242- Returns
243- -------
244- variable : Symbol
245- The created variable symbol.
246- """
247- if not isinstance (name , string_types ):
248- raise TypeError ('Expect a string for variable `name`' )
249- handle = SymbolHandle ()
250- check_call (_LIB .NNSymbolCreateVariable (c_str (name ), ctypes .byref (handle )))
251- ret = Symbol (handle )
252- attr = AttrScope .current .get (kwargs )
253- if attr :
254- ret ._set_attr (** attr )
255- return ret
256-
257-
258- def Group (symbols ):
259- """Create a symbol that groups symbols together.
260-
261- Parameters
262- ----------
263- symbols : list
264- List of symbols to be grouped.
265-
266- Returns
267- -------
268- sym : Symbol
269- The created group symbol.
270- """
271- ihandles = []
272- for sym in symbols :
273- if not isinstance (sym , Symbol ):
274- raise TypeError ('Expect Symbols in the list input' )
275- ihandles .append (sym .handle )
276- handle = SymbolHandle ()
277- check_call (_LIB .NNSymbolCreateGroup (
278- nn_uint (len (ihandles )),
279- c_array (SymbolHandle , ihandles ), ctypes .byref (handle )))
280- return Symbol (handle )
112+ def _set_symbol_class (cls ):
113+ global _symbol_cls
114+ _symbol_cls = cls
281115
282116
283117def _make_atomic_symbol_function (handle ):
@@ -332,7 +166,7 @@ def creator(*args, **kwargs):
332166 attr = kwargs .pop ('attr' , None )
333167
334168 for k , v in kwargs .items ():
335- if isinstance (v , Symbol ):
169+ if isinstance (v , SymbolBase ):
336170 symbol_kwargs [k ] = v
337171 else :
338172 param_keys .append (c_str (k ))
@@ -351,7 +185,7 @@ def creator(*args, **kwargs):
351185 raise TypeError (
352186 '%s can only accept input'
353187 'Symbols either as positional or keyword arguments, not both' % func_name )
354- s = Symbol (sym_handle )
188+ s = _symbol_cls (sym_handle )
355189 attr = AttrScope .current .get (attr )
356190 if attr :
357191 s ._set_attr (** attr )
@@ -373,11 +207,12 @@ def _init_symbol_module():
373207 check_call (_LIB .NNSymbolListAtomicSymbolCreators (ctypes .byref (size ),
374208 ctypes .byref (plist )))
375209 module_obj = sys .modules ["nnvm.symbol" ]
210+ module_internal = sys .modules ["nnvm._symbol_internal" ]
376211 for i in range (size .value ):
377212 hdl = SymbolHandle (plist [i ])
378213 function = _make_atomic_symbol_function (hdl )
379214 if function .__name__ .startswith ('_' ):
380- setattr (Symbol , function .__name__ , staticmethod ( function ) )
215+ setattr (module_internal , function .__name__ , function )
381216 else :
382217 setattr (module_obj , function .__name__ , function )
383218
0 commit comments