|
| 1 | +"""Shared functions and classes for frontends.""" |
| 2 | +from __future__ import absolute_import as _abs |
| 3 | +import warnings |
| 4 | +from .._base import string_types |
| 5 | + |
| 6 | +class Renamer(object): |
| 7 | + """A simply renamer for operators. |
| 8 | +
|
| 9 | + Parameters |
| 10 | + ---------- |
| 11 | + new_name : str |
| 12 | + The new name for the operator |
| 13 | + """ |
| 14 | + def __init__(self, new_name): |
| 15 | + self._new_name = new_name |
| 16 | + |
| 17 | + def __call__(self, attrs): |
| 18 | + return self._new_name, attrs |
| 19 | + |
| 20 | + |
| 21 | +class AttrConverter(object): |
| 22 | + """Common attribute conveter. An AttrConverter instance is a callable: |
| 23 | + ``` |
| 24 | + attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) |
| 25 | + new_op_name, new_attr = attr_converter(attrs) |
| 26 | + ``` |
| 27 | +
|
| 28 | + Parameters |
| 29 | + ---------- |
| 30 | + op_name : str or callable |
| 31 | + If set as str, returned operator name is the str. |
| 32 | + If set as callable, returned operator is the str returned by calling: |
| 33 | + `op_name = func(attr)` |
| 34 | + transforms : dict of `new_name, or (new_name, default_value, transform function)` |
| 35 | + If only a new_name is provided, it's like renaming the attribute name. |
| 36 | + If default_value if provded, then the attribute is considered as optional. |
| 37 | + If transform function is provided, the original attribute value is handled |
| 38 | + by transform function. |
| 39 | + excludes : list |
| 40 | + A list of excluded attributes that should `NOT` appear. |
| 41 | + Raise NotImplementedError if occured. |
| 42 | + disables : list |
| 43 | + A list of attributes that is disabled in nnvm. Raise warnings. |
| 44 | + ignores : list |
| 45 | + A list of attributes that is ignored in nnvm. Silent. |
| 46 | + extras : dict |
| 47 | + A series of additional attributes should be added anyway to the returned |
| 48 | + attribute dict. |
| 49 | + custom_check : callable |
| 50 | + A custom function takes attribute, and return True/False. |
| 51 | + Raise RuntimeError if not bool(True) returned. |
| 52 | + """ |
| 53 | + def __init__(self, op_name, transforms=None, |
| 54 | + excludes=None, disables=None, ignores=None, |
| 55 | + extras=None, custom_check=None): |
| 56 | + self._op_name = op_name |
| 57 | + self._transforms = transforms if transforms else {} |
| 58 | + self._excludes = excludes if excludes else [] |
| 59 | + self._disables = disables if disables else [] |
| 60 | + self._ignores = ignores if ignores else [] |
| 61 | + self._extras = extras if extras else {} |
| 62 | + self._custom_check = custom_check |
| 63 | + |
| 64 | + def __call__(self, attrs): |
| 65 | + # apply custom check |
| 66 | + if self._custom_check: |
| 67 | + func, msg = self._custom_check |
| 68 | + if not func(attrs): |
| 69 | + raise RuntimeError("Check failed: {}".format(msg)) |
| 70 | + # get new op_name |
| 71 | + if isinstance(self._op_name, string_types): |
| 72 | + op_name = self._op_name |
| 73 | + else: |
| 74 | + assert callable(self._op_name), "op_name can either be string or callable" |
| 75 | + op_name = self._op_name(attrs) |
| 76 | + # convert attributes |
| 77 | + new_attrs = {} |
| 78 | + for k in attrs.keys(): |
| 79 | + if k in self._excludes: |
| 80 | + raise NotImplementedError("Attribute {} not supported yet.".format(k)) |
| 81 | + elif k in self._disables: |
| 82 | + warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name)) |
| 83 | + elif k in self._ignores: |
| 84 | + pass |
| 85 | + elif k in self._transforms: |
| 86 | + new_name, defaults, transform = self._parse_default(self._transforms[k]) |
| 87 | + if defaults is None: |
| 88 | + new_attr = self._required_attr(attrs, k) |
| 89 | + else: |
| 90 | + new_attr = attrs.get(k, None) |
| 91 | + if new_attr is None: |
| 92 | + new_attrs[new_name] = defaults |
| 93 | + else: |
| 94 | + new_attrs[new_name] = transform(new_attr) |
| 95 | + else: |
| 96 | + # copy |
| 97 | + new_attrs[k] = attrs[k] |
| 98 | + # add extras |
| 99 | + new_attrs.update(self._extras) |
| 100 | + return op_name, new_attrs |
| 101 | + |
| 102 | + def _parse_default(self, target): |
| 103 | + """Helper function to parse default values.""" |
| 104 | + if not isinstance(target, (list, tuple)): |
| 105 | + k, v, t = target, None, lambda x: x |
| 106 | + elif len(target) == 1: |
| 107 | + k, v, t = target[0], None, lambda x: x |
| 108 | + elif len(target) == 2: |
| 109 | + k, v, t = target[0], target[1], lambda x: x |
| 110 | + elif len(target) > 2: |
| 111 | + k, v, t = target[0], target[1], target[2] |
| 112 | + else: |
| 113 | + k = None # should raise |
| 114 | + if not isinstance(k, string_types): |
| 115 | + msg = "{} is not a valid target, (name, default) expected.".format(target) |
| 116 | + raise ValueError(msg) |
| 117 | + return k, v, t |
| 118 | + |
| 119 | + def _parse_bool(self, value): |
| 120 | + """Helper function to parse default boolean values.""" |
| 121 | + if isinstance(value, string_types): |
| 122 | + return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] |
| 123 | + return bool(value) |
| 124 | + |
| 125 | + def _required_attr(self, attr, key): |
| 126 | + """Wrapper for getting required attributes.""" |
| 127 | + assert isinstance(attr, dict) |
| 128 | + if key not in attr: |
| 129 | + raise AttributeError("Required attribute {} not found.".format(key)) |
| 130 | + return attr[key] |
0 commit comments