|
6 | 6 | import logging |
7 | 7 | import tvm |
8 | 8 | import topi |
| 9 | +import re |
9 | 10 |
|
10 | 11 | from nnvm.top import registry as reg, OpPattern |
11 | 12 | from nnvm.top import nn as _nn |
@@ -340,7 +341,7 @@ def _get_workload(data, pad_data, kernel, output): |
340 | 341 | w_str = (i_w + w_pad*2 - k_w) // (o_w - 1) |
341 | 342 | return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str) |
342 | 343 |
|
343 | | -def schedule_packed_conv2d(outs, plan=None, skip_load_inp=False, skip_load_wgt=False, |
| 344 | +def schedule_packed_conv2d(outs, planStr=None, skip_load_inp=False, skip_load_wgt=False, |
344 | 345 | skip_load_acc=False, skip_store_out=False, skip_alu=False, |
345 | 346 | skip_gemm=False): |
346 | 347 | """ Schedule the packed conv2d. |
@@ -377,7 +378,23 @@ def _traverse(op): |
377 | 378 | else: |
378 | 379 | pad_data = None |
379 | 380 | wrkld = _get_workload(data, pad_data, kernel, output) |
380 | | - if plan is None: |
| 381 | + if planStr: |
| 382 | + matchObj = re.match( r'b(\d+)oc(\d+)ic(\d+)h(\d+)w(\d+)oc_t(\d+)h_t(\d+)', sched_str) |
| 383 | + b_factor = int(matchObj.group(1)) |
| 384 | + oc_factor = int(matchObj.group(2)) |
| 385 | + ic_factor = int(matchObj.group(3)) |
| 386 | + h_factor = int(matchObj.group(4)) |
| 387 | + w_factor = int(matchObj.group(5)) |
| 388 | + oc_nthread = int(matchObj.group(6)) |
| 389 | + h_nthread = int(matchObj.group(7)) |
| 390 | + plan = Schedule(b_factor=b_factor, |
| 391 | + oc_factor=oc_factor, |
| 392 | + ic_factor=ic_factor, |
| 393 | + h_factor=h_factor, |
| 394 | + w_factor=w_factor, |
| 395 | + oc_nthread=oc_nthread, |
| 396 | + h_nthread=h_nthread) |
| 397 | + else: |
381 | 398 | plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] |
382 | 399 | logging.info("Trying to find plan for %s", wrkld) |
383 | 400 | env = get_env() |
|
0 commit comments