88from ..pycode_writer import get_float_string
99from ..fx_graph import NodeArgs , FxGraph
1010from ..utils import NullObject , null_object
11- from ..store_pos import StorePos
11+ from ..store_pos import StorePos , StoreInFreeVar , StoreInAttr
1212if TYPE_CHECKING :
1313 from ..pycode_generator import GraphFnCodegen , GuardFnCodegen
14+ from ..object_table import ObjectTable
1415
1516
1617class NoneVar (Variable ):
@@ -132,6 +133,33 @@ def as_fx_node(self) -> NodeArgs:
132133 return slice (self .start , self .stop , self .step )
133134
134135
136+ class EllipsisVar (Variable ):
137+
138+ def __init__ (self , need_guard_check : bool , obj : Any ,
139+ extract_code_at_start : list [StorePos ]) -> None :
140+ super ().__init__ (need_guard_check , obj , extract_code_at_start )
141+
142+ def make_guard_inner (self , codegen : "GuardFnCodegen" ,
143+ pos : StorePos ) -> None :
144+ codegen .add_id_check (f"id({ pos } ) == { id (self .obj )} " , self .obj )
145+
146+ def make_output_inner (self , name_in_graph_fn : str , store_pos : StorePos ,
147+ codegen : "GraphFnCodegen" , in_return : bool ,
148+ idx : int ) -> None :
149+ name = codegen .add_obj (self .obj , "Ellipsis_VAR" )
150+ codegen .output (name_in_graph_fn , store_pos , name , in_return , idx )
151+
152+ @classmethod
153+ def from_value (cls , value : Any , need_guard_check : bool ,
154+ _helper_functions : HelperFunctions ,
155+ _fx_graph : Optional [FxGraph ],
156+ extract_code_at_start : list [StorePos ]) -> "EllipsisVar" :
157+ return cls (need_guard_check , value , extract_code_at_start )
158+
159+ def as_fx_node (self ) -> NodeArgs :
160+ return Ellipsis
161+
162+
135163torch_modules = set ([torch ])
136164
137165
@@ -162,10 +190,24 @@ def from_value(cls, value: ModuleType, need_guard_check: bool,
162190
163191
164192class FunctionVar (Variable ):
193+ closure_vars : list [Variable ]
194+ obj_ids : list [int ]
165195
166196 def __init__ (self , func : Callable [..., Any ], need_guard_check : bool ,
197+ helper_functions : HelperFunctions ,
167198 extract_code_at_start : list [StorePos ]) -> None :
168199 super ().__init__ (need_guard_check , func , extract_code_at_start )
200+ self .closure_vars = []
201+ self .obj_ids = []
202+ if hasattr (func , "__code__" ) and hasattr (func , "__closure__" ):
203+ if func .__closure__ is not None :
204+ assert len (func .__code__ .co_freevars ) == len (func .__closure__ )
205+ for i , x in enumerate (func .__closure__ ):
206+ if x .cell_contents != func :
207+ cell_var = helper_functions .get_or_make_var (
208+ x , need_guard_check , None , [StoreInFreeVar (i )])
209+ self .closure_vars .append (cell_var )
210+ self .obj_ids .append (id (x ))
169211
170212 def make_guard_inner (self , codegen : "GuardFnCodegen" ,
171213 pos : StorePos ) -> None :
@@ -187,17 +229,31 @@ def from_value(cls, value: Callable[..., Any], need_guard_check: bool,
187229 _helper_functions : HelperFunctions ,
188230 _fx_graph : Optional [FxGraph ],
189231 extract_code_at_start : list [StorePos ]) -> "FunctionVar" :
190- return cls (value , need_guard_check , extract_code_at_start )
232+ return cls (value , need_guard_check , _helper_functions ,
233+ extract_code_at_start )
234+
235+ def add_subvars_to_table (self , table : 'ObjectTable' ) -> None :
236+ for i , (var , idx ) in enumerate (zip (self .closure_vars , self .obj_ids )):
237+ old_var = table .get_or_none_by_id (idx )
238+ if old_var is not None :
239+ new_extract : list [StorePos ] = [StoreInFreeVar (i )]
240+ old_var .extract_code_at_start .extend (new_extract )
241+ old_var .need_guard_check |= self .need_guard_check
242+ else :
243+ table .add_by_id (var , idx )
244+ var .add_subvars_to_table (table )
245+
246+ # def as_fx_node(self) -> NodeArgs:
247+ # return self.obj
191248
192249
193250class RangeVar (Variable ):
194- start : Optional [ int ]
195- stop : Optional [ int ]
196- step : Optional [ int ]
251+ start : int
252+ stop : int
253+ step : int
197254
198- def __init__ (self , start : Optional [int ], stop : Optional [int ],
199- step : Optional [int ], need_guard_check : bool , obj : range ,
200- extract_code_at_start : list [StorePos ]) -> None :
255+ def __init__ (self , start : int , stop : int , step : int , need_guard_check : bool ,
256+ obj : range , extract_code_at_start : list [StorePos ]) -> None :
201257 super ().__init__ (need_guard_check , obj , extract_code_at_start )
202258 self .start = start
203259 self .stop = stop
@@ -222,3 +278,6 @@ def from_value(cls, value: range, need_guard_check: bool,
222278 extract_code_at_start : list [StorePos ]) -> "RangeVar" :
223279 return cls (value .start , value .stop , value .step , need_guard_check , value ,
224280 extract_code_at_start )
281+
282+ def as_fx_node (self ) -> NodeArgs :
283+ return range (self .start , self .stop , self .step )
0 commit comments