1818from typing import Dict , Union , Optional
1919import numpy as np
2020
21- from .. import auto_scheduler , relay , tir , nd , IRModule , build , topi , transform
21+ from .. import auto_scheduler , relay , tir , nd , IRModule , build , topi , transform , get_global_func
2222from ..target import Target
2323from ..runtime import profiler_vm , profiling , Device , num_threads
2424from ..script import tir as T
2525from ..ir .instrument import pass_instrument
2626from ..ir .expr import GlobalVar
27+ from ..rpc .base import RPC_SESS_MASK
28+ from ..rpc .client import RPCSession
29+ from ..contrib import utils
2730
2831
29- def _create_args (mod : IRModule , dev : Device , func_name : str = "main" ):
32+ def _create_args (mod : IRModule , dev : Device , func_name : str = "main" , remote = None ):
33+ if dev .device_type >= RPC_SESS_MASK :
34+ random_fill = remote .get_function ("tvm.contrib.random.random_fill" )
35+ else :
36+ random_fill = get_global_func ("tvm.contrib.random.random_fill" )
37+ assert random_fill , "Please make sure USE_RANDOM is ON in config.cmake"
3038 args = []
3139 for arg in mod [func_name ].params :
32- args .append (
33- nd .array (
34- np .zeros ([x .value for x in arg .type_annotation .shape ], arg .type_annotation .dtype ),
35- device = dev ,
36- )
40+ ary = nd .empty (
41+ [x .value for x in arg .type_annotation .shape ],
42+ arg .type_annotation .dtype ,
43+ device = dev ,
3744 )
45+ random_fill (ary )
46+ args .append (ary )
3847 return args
3948
4049
@@ -103,6 +112,7 @@ def estimate_peak_fma_flops(
103112 dev : Device ,
104113 vec_width : Optional [int ] = None ,
105114 num_vector_registers : Optional [int ] = None ,
115+ remote : Optional [RPCSession ] = None ,
106116) -> float :
107117 """
108118 Estimate the maximum number of FLOP/s this target/device combo is capable
@@ -123,6 +133,9 @@ def estimate_peak_fma_flops(
123133 num_vector_registers : Optional[int]
124134 Number of vector registers on the underlying hardware. Will try to
125135 infer if no value is provided.
136+ remote : Optional[RPCSession]
137+ Remote session used to upload artifacts for runtime evaluation. Must be
138+ the same session used to create `dev`.
126139
127140 Returns
128141 -------
@@ -146,7 +159,23 @@ def estimate_peak_fma_flops(
146159 )
147160 with transform .PassContext (opt_level = 3 ):
148161 f = build (specialized , target = target )
149- a = nd .array (np .ones ((nthreads , num_vector_registers , vec_width ), dtype = "float32" ), device = dev )
162+
163+ # upload to remote if running over rpc
164+ if dev .device_type >= RPC_SESS_MASK :
165+ if remote is None :
166+ raise RuntimeError ("A RPCSession must be provided when using a remote device." )
167+ temp = utils .tempdir ()
168+ path = temp .relpath ("peak_fma_flops.tar" )
169+ f .export_library (path )
170+ remote .upload (path )
171+ f = remote .load_module ("peak_fma_flops.tar" )
172+ random_fill = remote .get_function ("tvm.contrib.random.random_fill" )
173+ else :
174+ random_fill = get_global_func ("tvm.contrib.random.random_fill" )
175+ assert random_fill , "Please make sure USE_RANDOM is ON in config.cmake"
176+
177+ a = nd .empty ((nthreads , num_vector_registers , vec_width ), dtype = "float32" , device = dev )
178+ random_fill (a )
150179 times = f .time_evaluator (f .entry_name , dev , repeat = 100 , number = 1 )(a )
151180 flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops
152181 flop_s = flops / times .min
@@ -171,7 +200,12 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.
171200 B [i , l , j ] += A [i , k , l , j ]
172201
173202
174- def estimate_peak_bandwidth (target : Target , dev : Device , vec_width : Optional [int ] = None ) -> float :
203+ def estimate_peak_bandwidth (
204+ target : Target ,
205+ dev : Device ,
206+ vec_width : Optional [int ] = None ,
207+ remote : Optional [RPCSession ] = None ,
208+ ) -> float :
175209 """Estimate peak memory bandwidth of a target/device combo.
176210
177211 Peak bandwidth is estimated by running a small experiment on the underlying
@@ -187,6 +221,9 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
187221 Device to measure peak bandwidth on.
188222 vec_width : Optional[int]
189223 Vector unit width, determined from target if not supplied.
224+ remote : Optional[RPCSession]
225+ Remote session used to upload artifacts for runtime evaluation. Must be
226+ the same session used to create `dev`.
190227
191228 Returns
192229 -------
@@ -207,13 +244,30 @@ def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int
207244 )
208245 with transform .PassContext (opt_level = 3 ):
209246 f = build (specialized , target = target )
247+
248+ # upload to remote if running over rpc
249+ if dev .device_type >= RPC_SESS_MASK :
250+ if remote is None :
251+ raise RuntimeError ("A RPCSession must be provided when using a remote device." )
252+ temp = utils .tempdir ()
253+ path = temp .relpath ("peak_bandwidth.tar" )
254+ f .export_library (path )
255+ remote .upload (path )
256+ f = remote .load_module ("peak_bandwidth.tar" )
257+ random_fill = remote .get_function ("tvm.contrib.random.random_fill" )
258+ else :
259+ random_fill = get_global_func ("tvm.contrib.random.random_fill" )
260+ assert random_fill , "Please make sure USE_RANDOM is ON in config.cmake"
261+
210262 threads = num_threads ()
211263 # Data size needs to be larger than last level of cache. We don't have a
212264 # way of getting cache sizes, so this number should give us a large enough
213265 # size.
214266 size = 10 ** 8 // (4 * threads * vec_width )
215- a = nd .array (np .ones ((threads , size , 4 , vec_width ), dtype = "float32" ), device = dev )
216- b = nd .array (np .ones ((threads , vec_width , 4 ), dtype = "float32" ), device = dev )
267+ a = nd .empty ((threads , size , 4 , vec_width ), dtype = "float32" , device = dev )
268+ random_fill (a )
269+ b = nd .empty ((threads , vec_width , 4 ), dtype = "float32" , device = dev )
270+ random_fill (b )
217271 times = f .time_evaluator (f .entry_name , dev , repeat = 10 , number = 1 )(a , b , threads )
218272 return a .numpy ().size * 4 / times .min # 4 bytes per float32
219273
@@ -241,6 +295,7 @@ def roofline_from_existing(
241295 tir_functions : Dict [GlobalVar , tir .PrimFunc ],
242296 target : Target ,
243297 dev : Device ,
298+ remote : Optional [RPCSession ] = None ,
244299) -> profiling .Report :
245300 """Add roofline and other estimated statistics to an existing profiling report.
246301
@@ -290,6 +345,9 @@ def roofline_from_existing(
290345 TVM target that `report` was generated with.
291346 dev : Device
292347 Device that `report` was generated with.
348+ remote : Optional[RPCSession]
349+ Remote session used to upload artifacts for runtime evaluation. Must be
350+ the same session used to create `dev`.
293351
294352 Returns
295353 -------
@@ -299,8 +357,8 @@ def roofline_from_existing(
299357 :py:func:`roofline_analysis` for more information on which metrics
300358 are included.
301359 """
302- peak_bandwidth = estimate_peak_bandwidth (target , dev )
303- peak_flops = estimate_peak_fma_flops (target , dev )
360+ peak_bandwidth = estimate_peak_bandwidth (target , dev , remote = remote )
361+ peak_flops = estimate_peak_fma_flops (target , dev , remote = remote )
304362
305363 ridge_point = peak_flops / peak_bandwidth
306364
@@ -346,7 +404,11 @@ def roofline_from_existing(
346404
347405
348406def roofline_analysis (
349- mod : IRModule , params : Dict [str , nd .NDArray ], target : Union [str , Target ], dev : Device
407+ mod : IRModule ,
408+ params : Dict [str , nd .NDArray ],
409+ target : Union [str , Target ],
410+ dev : Device ,
411+ remote : Optional [RPCSession ] = None ,
350412) -> profiling .Report :
351413 """
352414 Create a profiling report that contains roofline and other estimated
@@ -385,6 +447,10 @@ def roofline_analysis(
385447 dev : Device
386448 Device to run on.
387449
450+ remote : Optional[RPCSession]
451+ Remote session used to upload artifacts for runtime evaluation. Must be
452+ the same session used to create `dev`.
453+
388454 Returns
389455 -------
390456
@@ -405,9 +471,18 @@ def roofline_analysis(
405471 config = pass_ctx .config ,
406472 ):
407473 lib = relay .vm .compile (mod , params = params , target = target )
474+ # upload to remote if running over rpc
475+ if dev .device_type >= RPC_SESS_MASK :
476+ if remote is None :
477+ raise RuntimeError ("A RPCSession must be provided when using a remote device." )
478+ temp = utils .tempdir ()
479+ path = temp .relpath ("roofline_lib.tar" )
480+ lib .mod .export_library (path )
481+ remote .upload (path )
482+ lib = remote .load_module ("roofline_lib.tar" )
408483 vmexec = profiler_vm .VirtualMachineProfiler (lib , dev )
409484
410- args = _create_args (mod , dev )
485+ args = _create_args (mod , dev , remote = remote )
411486 report = vmexec .profile (* args )
412487
413- return roofline_from_existing (report , save_tir .functions , target , dev )
488+ return roofline_from_existing (report , save_tir .functions , target , dev , remote = remote )
0 commit comments