
Commit b3c000f

Parity 1 (apache#24)
2 parents b8e6464 + 6c24c8d commit b3c000f

15 files changed

Lines changed: 489 additions & 80 deletions

frontend/guard_tracker.py

Lines changed: 174 additions & 34 deletions
Large diffs are not rendered by default.

frontend/object_table.py

Lines changed: 5 additions & 4 deletions
@@ -7,7 +7,7 @@
 from .utils import NullObject, ReadOnlyObject
 from .store_pos import StorePos
 from .fx_graph import FxGraph
-import torch
+import numpy as np


 class ObjectTable:
@@ -35,7 +35,6 @@ def add(self, var: Variable, value: Any) -> None:
             old_var.need_guard_check |= var.need_guard_check
         else:
             self.add_by_id(var, id(value))
-            var.add_subvars_to_table(self)

     def add_by_id(self, var: Variable, idx: int) -> None:
         assert idx not in self.objs
@@ -68,11 +67,13 @@ def get(self,
             return self.objs[id(value)]
         elif allow_unexist_const:
             if isinstance(value, get_args(CONST_TYPES)) or isinstance(
-                    value, (list, tuple, set, dict, CodeType)):
+                    value, (list, tuple, set, dict, range, CodeType,
+                            type(Ellipsis), np.ndarray)):
                 return make_var_from_value(value, False, self.helper_functions,
                                            fx_graph)
         raise RuntimeError(
-            f"Object({id(value)}) {value} not found in object table")
+            f"Object({id(value)}) {value} {type(value)} not found in object table"
+        )

     def get_or_none(self, value: Any) -> Optional[Variable]:
         if id(value) in self.objs:
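
Note on this hunk: the widened isinstance check lets ObjectTable.get wrap range, Ellipsis, and numpy.ndarray values on the fly under allow_unexist_const instead of raising. A minimal standalone sketch of that predicate, where CONST_TYPES is a simplified stand-in for the tuple defined in frontend/variables/__init__.py:

import numpy as np
from types import CodeType

# Simplified stand-in for get_args(CONST_TYPES) in the real code.
CONST_TYPES = (int, float, bool, str, type(None), slice)

def can_wrap_on_the_fly(value):
    # Mirrors the condition added to ObjectTable.get under allow_unexist_const.
    return isinstance(value, CONST_TYPES) or isinstance(
        value, (list, tuple, set, dict, range, CodeType, type(Ellipsis),
                np.ndarray))

print(can_wrap_on_the_fly(range(3)))     # True (newly accepted)
print(can_wrap_on_the_fly(...))          # True (newly accepted)
print(can_wrap_on_the_fly(np.zeros(2)))  # True (newly accepted)
print(can_wrap_on_the_fly(object()))     # False -> RuntimeError in the real code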

frontend/utils.py

Lines changed: 26 additions & 9 deletions
@@ -8,6 +8,7 @@
 import contextlib
 import torch
 import torch._C
+import collections
 from .config import get_config, set_config

 if TYPE_CHECKING:
@@ -93,11 +94,15 @@ def is_call_bytecode(inst: 'Instruction') -> bool:
     operator.rshift,
     operator.and_,
     operator.or_,
+    operator.is_,
     operator.xor,
     operator.eq,
     operator.lt,
     operator.ne,
     operator.le,
+    operator.gt,
+    operator.ge,
+    operator.contains,
 }
 fx_graph_functions = fx_graph_functions.union(fx_graph_inplace_functions)

@@ -124,7 +129,7 @@ def get_root_module(func: Callable[..., Any]) -> str:
     if hasattr(func, '__objclass__'):
         if func.__objclass__ == torch._C._TensorBase:
             return 'torch'
-        elif func.__objclass__ in (list, tuple, set, dict):
+        elif func.__objclass__ in (list, tuple, set, dict, str):
             return 'builtins'

     if hasattr(func, '__self__') and isinstance(func.__self__, torch.Tensor):
@@ -135,9 +140,13 @@ def get_root_module(func: Callable[..., Any]) -> str:
         return 'numpy'

     module = inspect.getmodule(func)
-    if module is None:
+    module_str = ""
+    if module is not None:
+        module_str = str(module).split('\'')[1]
+    if module is None or module_str in ('torch.distributions.bernoulli',
+                                        'torch.distributions.distribution'):
         return ""
-    root_module = str(module).split('\'')[1].split('.')[0]
+    root_module = module_str.split('.')[0]
     return root_module


@@ -161,15 +170,18 @@ def get_method_defined_class(cls: type[Any],

 def is_user_defined_func(func: Callable[..., Any]) -> bool:
     # print([(x, getattr(func, x)) for x in dir(func)])
-    if hasattr(func,
-               '__objclass__') and func.__objclass__ in (torch._C._TensorBase,
-                                                         dict):
+    if hasattr(func, '__objclass__') and func.__objclass__ in (
+            torch._C._TensorBase, dict, str, collections.OrderedDict):
         return False

     # NOTE: random should be called as a UDF, not handled
-    if hasattr(func, '__self__') and isinstance(func.__self__,
-                                                (torch.Tensor, random.Random)):
-        return False
+    if hasattr(func, '__self__'):
+        if isinstance(func.__self__, (torch.Tensor, random.Random)):
+            return False
+        elif isinstance(func.__self__, (list, tuple, set, dict, str)):
+            return False
+        elif isinstance(func.__self__, torch.nn.Sequential):
+            return True

     if hasattr(func, '__name__') and func.__name__ == '<genexpr>':
         return False
@@ -213,6 +225,11 @@ def is_graph_func(func: Callable[..., Any]) -> bool:
     return root_module == 'torch'


+def is_math_func(func: Callable[..., Any]) -> bool:
+    root_module = get_root_module(func)
+    return root_module == 'math'
+
+
 random_state = None

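
Note on this hunk: the new is_math_func is a thin wrapper over get_root_module, which maps a callable to the top-level name of its defining module. A rough standalone sketch of that dispatch; get_root_module_sketch below is a simplified stand-in, not the repository's implementation (the real code parses str(module) and special-cases torch/numpy methods and two torch.distributions modules):

import inspect
import math

def get_root_module_sketch(func):
    # Resolve the defining module and keep only its top-level package name.
    module = inspect.getmodule(func)
    if module is None:
        return ""
    return module.__name__.split('.')[0]

print(get_root_module_sketch(math.sqrt))  # 'math' -> is_math_func would return True
print(get_root_module_sketch(print))      # 'builtins'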

frontend/variables/__init__.py

Lines changed: 7 additions & 3 deletions
@@ -8,11 +8,11 @@
 from .tensor import TensorVar, TorchParamVar, TorchSizeVar, TorchDtypeVar, TorchDeviceVar
 from .torch_module import TorchModuleVar, TorchSequentialVar, TorchModuleListVar
 from .any_ import AnyVar
-from .const import NullVar, NoneVar, SliceVar, ModuleVar, FunctionVar, RangeVar, CodeVar
+from .const import NullVar, NoneVar, SliceVar, ModuleVar, FunctionVar, RangeVar, CodeVar, EllipsisVar
 from .iterator import IteratorVar, RangeIterVar
 from .tuple_ import TupleVar
 from .set_ import SetVar
-from .list_ import ListVar
+from .list_ import ListVar, NdarrayVar
 from .dict_ import DictVar, OrderedDictVar
 from .builtin_types import CellVar, MappingProxyVar
 from ..fx_graph import FxGraph
@@ -37,7 +37,8 @@
     torch.device: TorchDeviceVar,
     dict: DictVar,
     CodeType: CodeVar,
-    OrderedDict: OrderedDictVar
+    OrderedDict: OrderedDictVar,
+    np.ndarray: NdarrayVar,
 }

 CONST_TYPES = Union[int, float, bool, str, NullObject, None, slice]
@@ -86,6 +87,9 @@ def make_var_from_value(
         return MappingProxyVar.from_value(value, need_guard_check,
                                           helper_functions, fx_graph,
                                           extract_code_at_start)
+    elif isinstance(value, type(Ellipsis)):
+        return EllipsisVar.from_value(value, need_guard_check, helper_functions,
+                                      fx_graph, extract_code_at_start)
     else:
         # NOTE: use any instead of iteartor_var to represent iterator with unknown source due to the hardness of getting iterable and num_iters
         print("generate any for", value, type(value), extract_code_at_start)

frontend/variables/const.py

Lines changed: 67 additions & 8 deletions
@@ -8,9 +8,10 @@
 from ..pycode_writer import get_float_string
 from ..fx_graph import NodeArgs, FxGraph
 from ..utils import NullObject, null_object
-from ..store_pos import StorePos
+from ..store_pos import StorePos, StoreInFreeVar, StoreInAttr
 if TYPE_CHECKING:
     from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
+    from ..object_table import ObjectTable


 class NoneVar(Variable):
@@ -132,6 +133,33 @@ def as_fx_node(self) -> NodeArgs:
         return slice(self.start, self.stop, self.step)


+class EllipsisVar(Variable):
+
+    def __init__(self, need_guard_check: bool, obj: Any,
+                 extract_code_at_start: list[StorePos]) -> None:
+        super().__init__(need_guard_check, obj, extract_code_at_start)
+
+    def make_guard_inner(self, codegen: "GuardFnCodegen",
+                         pos: StorePos) -> None:
+        codegen.add_id_check(f"id({pos}) == {id(self.obj)}", self.obj)
+
+    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
+                          codegen: "GraphFnCodegen", in_return: bool,
+                          idx: int) -> None:
+        name = codegen.add_obj(self.obj, "Ellipsis_VAR")
+        codegen.output(name_in_graph_fn, store_pos, name, in_return, idx)
+
+    @classmethod
+    def from_value(cls, value: Any, need_guard_check: bool,
+                   _helper_functions: HelperFunctions,
+                   _fx_graph: Optional[FxGraph],
+                   extract_code_at_start: list[StorePos]) -> "EllipsisVar":
+        return cls(need_guard_check, value, extract_code_at_start)
+
+    def as_fx_node(self) -> NodeArgs:
+        return Ellipsis
+
+
 torch_modules = set([torch])


@@ -162,10 +190,24 @@ def from_value(cls, value: ModuleType, need_guard_check: bool,


 class FunctionVar(Variable):
+    closure_vars: list[Variable]
+    obj_ids: list[int]

     def __init__(self, func: Callable[..., Any], need_guard_check: bool,
+                 helper_functions: HelperFunctions,
                  extract_code_at_start: list[StorePos]) -> None:
         super().__init__(need_guard_check, func, extract_code_at_start)
+        self.closure_vars = []
+        self.obj_ids = []
+        if hasattr(func, "__code__") and hasattr(func, "__closure__"):
+            if func.__closure__ is not None:
+                assert len(func.__code__.co_freevars) == len(func.__closure__)
+                for i, x in enumerate(func.__closure__):
+                    if x.cell_contents != func:
+                        cell_var = helper_functions.get_or_make_var(
+                            x, need_guard_check, None, [StoreInFreeVar(i)])
+                        self.closure_vars.append(cell_var)
+                        self.obj_ids.append(id(x))

     def make_guard_inner(self, codegen: "GuardFnCodegen",
                          pos: StorePos) -> None:
@@ -187,17 +229,31 @@ def from_value(cls, value: Callable[..., Any], need_guard_check: bool,
                    _helper_functions: HelperFunctions,
                    _fx_graph: Optional[FxGraph],
                    extract_code_at_start: list[StorePos]) -> "FunctionVar":
-        return cls(value, need_guard_check, extract_code_at_start)
+        return cls(value, need_guard_check, _helper_functions,
+                   extract_code_at_start)
+
+    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
+        for i, (var, idx) in enumerate(zip(self.closure_vars, self.obj_ids)):
+            old_var = table.get_or_none_by_id(idx)
+            if old_var is not None:
+                new_extract: list[StorePos] = [StoreInFreeVar(i)]
+                old_var.extract_code_at_start.extend(new_extract)
+                old_var.need_guard_check |= self.need_guard_check
+            else:
+                table.add_by_id(var, idx)
+                var.add_subvars_to_table(table)
+
+    # def as_fx_node(self) -> NodeArgs:
+    #     return self.obj


 class RangeVar(Variable):
-    start: Optional[int]
-    stop: Optional[int]
-    step: Optional[int]
+    start: int
+    stop: int
+    step: int

-    def __init__(self, start: Optional[int], stop: Optional[int],
-                 step: Optional[int], need_guard_check: bool, obj: range,
-                 extract_code_at_start: list[StorePos]) -> None:
+    def __init__(self, start: int, stop: int, step: int, need_guard_check: bool,
+                 obj: range, extract_code_at_start: list[StorePos]) -> None:
         super().__init__(need_guard_check, obj, extract_code_at_start)
         self.start = start
         self.stop = stop
@@ -222,3 +278,6 @@ def from_value(cls, value: range, need_guard_check: bool,
                    extract_code_at_start: list[StorePos]) -> "RangeVar":
         return cls(value.start, value.stop, value.step, need_guard_check, value,
                    extract_code_at_start)
+
+    def as_fx_node(self) -> NodeArgs:
+        return range(self.start, self.stop, self.step)
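
Note on this hunk: the new FunctionVar closure handling walks func.__closure__, which CPython keeps aligned with func.__code__.co_freevars; that pairing is exactly what the assert checks before each cell becomes a tracked sub-variable. A minimal standalone illustration (make_adder is a hypothetical example function, not part of the repository):

def make_adder(n):
    def add(x):
        return x + n
    return add

add3 = make_adder(3)
# co_freevars and __closure__ are parallel sequences: one cell per free variable.
print(add3.__code__.co_freevars)                    # ('n',)
print([c.cell_contents for c in add3.__closure__])  # [3]
assert len(add3.__code__.co_freevars) == len(add3.__closure__)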

frontend/variables/dict_.py

Lines changed: 12 additions & 4 deletions
@@ -73,10 +73,18 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
             codegen.output(name_in_graph_fn, store_pos, str(old_store_pos),
                            in_return, idx)
         else:
-            codegen.output(
-                name_in_graph_fn, store_pos,
-                f"{{{','.join(f'{key}: {name_in_graph_fn}_{j}' for key, j in zip(self.value.keys(), range(len(self.vars))))}}}"
-                if len(self.vars) > 0 else "{}", in_return, idx)
+            items = []
+            for key, j in zip(self.value.keys(), range(len(self.vars))):
+                if isinstance(key, str):
+                    key_part = f"'{key}'"
+                else:
+                    key_part = key
+                item = f'{key_part}: {name_in_graph_fn}_{j}'
+                items.append(item)
+            target = f"{{{', '.join(i for i in items)}}}"
+            codegen.output(name_in_graph_fn, store_pos,
+                           target if len(self.vars) > 0 else "{}", in_return,
+                           idx)

     @classmethod
     def from_value(cls,
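
Note on this hunk: the rewritten DictVar.make_output_inner quotes string keys before emitting the dict literal, so generated source stays valid when keys are strings. A standalone sketch of the quoting logic with hypothetical names (render_dict_literal is illustrative only):

def render_dict_literal(keys, prefix):
    # String keys get quoted so the emitted source is a valid dict literal;
    # other keys (ints, etc.) are written as-is.
    items = []
    for j, key in enumerate(keys):
        key_part = f"'{key}'" if isinstance(key, str) else key
        items.append(f"{key_part}: {prefix}_{j}")
    return f"{{{', '.join(items)}}}" if items else "{}"

print(render_dict_literal(["a", 1], "d"))  # {'a': d_0, 1: d_1}
print(render_dict_literal([], "d"))        # {}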

frontend/variables/list_.py

Lines changed: 71 additions & 1 deletion
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Optional, Any, Callable
 from copy import copy
+import numpy as np
 from .base import Variable, HelperFunctions
 from ..fx_graph import NodeArgs, FxGraph
 from ..store_pos import StorePos, StoreInIndex
@@ -83,4 +84,73 @@ def add_subvars_to_table(self, table: 'ObjectTable') -> None:
                 old_var.need_guard_check |= self.need_guard_check
             else:
                 table.add_by_id(var, idx)
-                var.add_subvars_to_table(table)
+                var.add_subvars_to_table(table)
+
+
+class NdarrayVar(Variable):
+    vars: list[Variable]
+    obj_ids: list[int]
+    length: int
+
+    def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
+                 helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
+                 extract_code_at_start: list[StorePos]) -> None:
+        super().__init__(need_guard_check, value, extract_code_at_start)
+        self.value = value
+        self.length = len(value)
+        self.vars = []
+        self.obj_ids = []
+        for i, obj in enumerate(value):
+            new_extract: list[StorePos] = [
+                StoreInIndex(pos, id(obj), i)
+                for pos in self.extract_code_at_start
+            ]
+            var = helper_functions.get_or_make_var(obj, need_guard_check,
+                                                   fx_graph, new_extract)
+            self.vars.append(var)
+            self.obj_ids.append(id(obj))
+
+    def make_guard_inner(self, codegen: "GuardFnCodegen",
+                         pos: StorePos) -> None:
+        codegen.add_import("numpy")
+        codegen.add_check(f"isinstance({pos}, numpy.ndarray)")
+        codegen.add_check(f"len({pos}) == {self.length}")
+        for i, obj in enumerate(self.vars):
+            obj.make_guard_inner(codegen, StoreInIndex(pos, id(obj), i))
+
+    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
+                          codegen: "GraphFnCodegen", in_return: bool,
+                          idx: int) -> None:
+        for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
+            var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
+                            False, idx_j)
+        list_str = f"[{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},]" if len(
+            self.vars) > 0 else "[]"
+        codegen.add_import("numpy")
+        var_str = f"numpy.array({list_str})"
+        codegen.output(name_in_graph_fn, store_pos, var_str, in_return, idx)
+
+    @classmethod
+    def from_value(cls, value: np.ndarray[Any, Any], need_guard_check: bool,
+                   helper_functions: HelperFunctions,
+                   fx_graph: Optional[FxGraph],
+                   extract_code_at_start: list[StorePos]) -> "NdarrayVar":
+        return cls(value, need_guard_check, helper_functions, fx_graph,
+                   extract_code_at_start)
+
+    def as_fx_node(self) -> NodeArgs:
+        return self.value
+
+    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
+        for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)):
+            old_var = table.get_or_none_by_id(idx)
+            if old_var is not None:
+                new_extract: list[StorePos] = [
+                    StoreInIndex(pos, idx, i)
+                    for pos in self.extract_code_at_start
+                ]
+                old_var.extract_code_at_start.extend(new_extract)
+                old_var.need_guard_check |= self.need_guard_check
+            else:
+                table.add_by_id(var, idx)
+                var.add_subvars_to_table(table)
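
Note on this hunk: NdarrayVar treats an ndarray much like a list. It guards on the type and the first-dimension length, tracks each top-level element as a sub-variable, and rebuilds the array in the generated code with numpy.array(...). A minimal sketch of that round trip (illustrative values only):

import numpy as np

value = np.array([[1, 2], [3, 4]])
# Each top-level element becomes its own tracked sub-variable...
elements = [value[i] for i in range(len(value))]
# ...and the graph function reassembles the array with numpy.array([...]).
rebuilt = np.array(elements)
assert rebuilt.shape == value.shape and (rebuilt == value).all()
print(len(value))  # 2 -> the value pinned by the len(...) guard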

frontend/variables/scalar.py

Lines changed: 5 additions & 4 deletions
@@ -22,8 +22,9 @@ def __init__(self, value: ScalarType, value_fix: bool,
                  need_guard_check: bool, fx_node: Optional[torch.fx.Node],
                  extract_code_at_start: list[StorePos]) -> None:
         super().__init__(need_guard_check, value, extract_code_at_start)
-        if isinstance(value, bool) and not value_fix:
-            raise NotImplementedError
+        # NOTE: should implement bool generated from tensor
+        # if isinstance(value, bool) and not value_fix:
+        #     raise NotImplementedError
         if not value_fix:
             assert fx_node is not None
         self.value_fix = value_fix
@@ -119,8 +120,8 @@ def __init__(self, value: np.generic, value_fix: bool,

     def make_guard_inner(self, codegen: "GuardFnCodegen",
                          pos: StorePos) -> None:
-        codegen.add_check(
-            f"isinstance({pos}.item(), {type(self.obj).__name__})")
+        codegen.add_import("numpy")
+        codegen.add_check(f"isinstance({pos}, numpy.{type(self.obj).__name__})")
         if self.value_fix:
             item = self.obj.item()
             if type(item) == float:
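
Note on this hunk: for numpy scalars the guard is now generated against the numpy type itself rather than the Python type of `.item()`, which is why the guard function also needs a numpy import. A small illustration of the string being emitted, where x stands in for the guarded position:

import numpy as np

obj = np.float32(1.5)
# The guard source now names the numpy scalar type directly.
guard_src = f"isinstance(x, numpy.{type(obj).__name__})"
print(guard_src)                    # isinstance(x, numpy.float32)
print(isinstance(obj, np.float32))  # True, so the emitted guard passes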
