2222from ..target import Target
2323from ..runtime import profiler_vm , profiling , Device , num_threads
2424from ..script import tir as T
25+ from ..ir .instrument import pass_instrument
26+ from ..ir .expr import GlobalVar
2527
2628
2729def _create_args (mod : IRModule , dev : Device , func_name : str = "main" ):
@@ -36,16 +38,6 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main"):
3638 return args
3739
3840
39- def _estimated_features (mod : IRModule , params : Dict [str , nd .NDArray ], target : Target ):
40- comp = relay .vm .VMCompiler ()
41- mod , params = comp .optimize (mod , params = params , target = target )
42- return {
43- prim .attrs ["hash" ]: (name , auto_scheduler .feature .named_features_from_primfunc (prim ))
44- for name , prim in mod .functions .items ()
45- if isinstance (prim , tir .PrimFunc )
46- }
47-
48-
4941def _detect_vec_width_registers (
5042 target : Target , vec_width : Optional [int ], num_vector_registers : Optional [int ]
5143):
@@ -226,60 +218,98 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
226218 return a .numpy ().size * 4 / times .min # 4 bytes per float32
227219
228220
229- def roofline_analysis (
230- mod : IRModule , params : Dict [str , nd .NDArray ], target : Union [str , Target ], dev : Device
@pass_instrument
class SaveLoweredTIR:
    """Pass instrument that records TIR functions from right before final lowering.

    Right now "right before final lowering" means right before
    ``tir.MakePackedAPI``: every module seen after a pass is recorded until a
    pass with that name is observed, after which recording stops for good.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Maps each key of ``mod.functions`` to the most recently seen function
        # stored under that key. Later passes overwrite earlier entries.
        self.functions = {}
        # Flips to True once ``tir.MakePackedAPI`` has been observed; nothing
        # is recorded after that point.
        self.done = False

    def run_after_pass(self, mod, info):
        """Record the functions of `mod` unless final lowering was reached.

        Parameters
        ----------
        mod : IRModule
            Module as it looks after the pass has run. Only its ``functions``
            mapping is read.
        info : PassInfo
            Description of the pass that just ran. Only its ``name`` is read.
        """
        if self.done:
            return

        if info.name == "tir.MakePackedAPI":
            # The module produced by this pass is deliberately not saved: the
            # snapshot we keep is the one from the pass before it.
            self.done = True
            return

        for global_var, lowered in mod.functions.items():
            self.functions[global_var] = lowered
237+
238+
239+ def roofline_from_existing (
240+ report : profiling .Report ,
241+ tir_functions : Dict [GlobalVar , tir .PrimFunc ],
242+ target : Target ,
243+ dev : Device ,
231244) -> profiling .Report :
232- """
233- Create a profiling report that contains roofline and other estimated
234- statistics from running a module on the VM.
245+ """Add roofline and other estimated statistics to an existing profiling report.
235246
236- These statistics are calculated by analyzing the lowered TIR of each
237- operator, so they are estimates of the true values. The statistics are:
238- - Bound: Is the operator memory or compute bound. This is computed by
239- assuming that the operator could perfectly cache all loads -- each byte
240- of memory is only loaded once.
241- - Percent of Theoretical Optimal: What percent of theoretical optimal for
242- the bound. i.e. percent of peak memory bandwidth if memory bound,
243- percent of peak FLOP/s if compute bound.
244- - Loaded Bytes: estimation of the number of bytes loaded from main memory.
245- - Estimated Flops: estimated number of floating point operations.
246- - Arithmetic Intensity: ratio of FLOPs per byte of data.
247- - FLOP/s: floating point operations per second.
248- - Bandwidth: Number of bytes loaded per second.
247+ :py:func:`roofline_analysis` should always be used instead of this function
248+ unless you need a custom compilation pipeline.
249249
250- Parameters
251- ----------
252- mod : IRModule
253- Uncompiled input module>
250+ Calculating roofline statistics requires features extracted from the TIR
251+ functions in addition to per-operator runtime information (`report`) of the
252+ same TIR features. The features and TIR functions are not included with the
253+ compiled library used to generate the per-operator runtime. It is essential
254+ that the per-operator information comes from the exact same compilation
255+ pipeline as the TIR functions.
254256
255- params : Dict[str, nd.NDArray]
256257
257- target : Union[str, Target]
258- Target to run on.
258+ Example
259+ -------
260+
261+ .. code:: python
262+
263+ import tvm
264+ import tvm.relay
265+
266+ mod, params = tvm.relay.testing.mlp.get_workload()
267+
268+ # it is recommended to use SaveLoweredTIR to get out the tir primfuncs
269+ save_tir = tvm.utils.roofline.SaveLoweredTIR()
270+ with tvm.transform.PassContext(opt_level=3, instruments=[save_tir]):
271+ lib = relay.vm.compile(mod, params=params, target=target)
272+
273+ vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
274+ report = vmexec.profile(*inputs)
275+
276+ roofline_report = roofline_from_existing(report, save_tir.functions, target, dev)
259277
278+
279+ Parameters
280+ ----------
281+ report : Report
282+ Existing profiling report from :py:meth:`VirtualMachineProfiler.profile`.
283+ tir_functions : Dict[GlobalVar, PrimFunc]
284+ TIR primfuncs from the module run to generate `report`. It is necessary
285+ that these functions come before the `tir.MakePackedAPI` pass and are
286+ compatible with auto_scheduler featurization.
287+ :py:class:`SaveLoweredTIR` is the recommended way to collect these
288+ functions.
289+ target : Target
290+ TVM target that `report` was generated with.
260291 dev : Device
261- Device to run on .
292+ Device that `report` was generated with.
262293
263294 Returns
264295 -------
265-
266- report : profiling.Report
267- Profiling report which includes the estimated statistics.
296+ profiling.Report
297+ New profiling report that includes all information from `report`
298+ along with additional roofline metrics. See
299+ :py:func:`roofline_analysis` for more information on which metrics
300+ are included.
268301 """
269- if isinstance (target , str ):
270- target = Target (target )
271302 peak_bandwidth = estimate_peak_bandwidth (target , dev )
272303 peak_flops = estimate_peak_fma_flops (target , dev )
273304
274305 ridge_point = peak_flops / peak_bandwidth
275306
276- all_features = _estimated_features (mod , params , target )
277-
278- lib = relay .vm .compile (mod , params = params , target = target )
279- vmexec = profiler_vm .VirtualMachineProfiler (lib , dev )
307+ all_features = {
308+ prim .attrs ["hash" ]: (name , auto_scheduler .feature .named_features_from_primfunc (prim ))
309+ for name , prim in tir_functions .items ()
310+ if isinstance (prim , tir .PrimFunc ) and "hash" in prim .attrs .keys ()
311+ }
280312
281- args = _create_args (mod , dev )
282- report = vmexec .profile (* args )
283313 new_calls = []
284314 for call in report .calls :
285315 if "Hash" in call .keys ():
@@ -313,3 +343,71 @@ def roofline_analysis(
313343 else :
314344 new_calls .append (call )
315345 return profiling .Report (new_calls , report .device_metrics )
346+
347+
def roofline_analysis(
    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
) -> profiling.Report:
    """
    Create a profiling report that contains roofline and other estimated
    statistics from running a module on the VM.

    The roofline model measures how close an operator gets to the best possible
    memory bandwidth or FLOP/s, depending on whether it is memory or compute
    bound. This computation uses the runtime of the operator along with two
    numbers extracted from the TIR code: bytes of memory touched and number of
    floating point operations.

    These statistics are calculated by analyzing the lowered TIR of each
    operator, so they are estimates of the true values. The statistics are:
      - Bound: Is the operator memory or compute bound. This is computed by
        assuming that the operator could perfectly cache all loads -- each byte
        of memory is only loaded once.
      - Percent of Theoretical Optimal: What percent of theoretical optimal for
        the bound. i.e. percent of peak memory bandwidth if memory bound,
        percent of peak FLOP/s if compute bound.
      - Loaded Bytes: estimation of the number of bytes loaded from main memory.
      - Estimated Flops: estimated number of floating point operations.
      - Arithmetic Intensity: ratio of FLOPs per byte of data.
      - FLOP/s: floating point operations per second.
      - Bandwidth: Number of bytes loaded per second.

    Parameters
    ----------
    mod : IRModule
      Uncompiled input module.

    params : Dict[str, nd.NDArray]
      Parameters handed to `relay.vm.compile` together with `mod`.

    target : Union[str, Target]
      Target to run on. A string is converted to a `Target` first.

    dev : Device
      Device to run on.

    Returns
    -------

    report : profiling.Report
      Profiling report which includes the estimated statistics.
    """
    if isinstance(target, str):
        target = Target(target)

    tir_recorder = SaveLoweredTIR()

    # Mirror whatever PassContext the caller is already inside, adding only our
    # recorder to its instruments, so compilation otherwise runs unchanged.
    ambient_ctx = transform.PassContext.current()
    all_instruments = list(ambient_ctx.instruments) + [tir_recorder]
    with transform.PassContext(
        opt_level=ambient_ctx.opt_level,
        required_pass=ambient_ctx.required_pass,
        disabled_pass=ambient_ctx.disabled_pass,
        instruments=all_instruments,
        config=ambient_ctx.config,
    ):
        lib = relay.vm.compile(mod, params=params, target=target)

    profiler = profiler_vm.VirtualMachineProfiler(lib, dev)
    vm_inputs = _create_args(mod, dev)
    report = profiler.profile(*vm_inputs)

    # The recorder now holds the TIR from this exact compilation, which is what
    # the roofline computation must be fed alongside the matching report.
    return roofline_from_existing(report, tir_recorder.functions, target, dev)
0 commit comments