Skip to content

Commit 56c57e5

Browse files
committed
[AUTOMATION] Infrastructure to use hardware and schedules from vta-experiments (apache#7)
1 parent cd21608 commit 56c57e5

6 files changed

Lines changed: 35 additions & 18 deletions

File tree

vta/config/pynq_sample.json

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
{
22
"TARGET" : "pynq",
3-
"HW_VER" : "0.0.1",
3+
"HW_VER" : "0.0.2",
44
"HW_FREQ" : 100,
5-
"HW_CLK_TARGET" : 8,
6-
"ALU" : true,
7-
"GEMM_II" : 2,
8-
"TALU_II" : 4,
5+
"HW_CLK_TARGET" : 7,
6+
"ALU_EN" : true,
7+
"MUL_EN" : true,
8+
"GEMM_II" : 1,
9+
"TALU_II" : 2,
910
"LOG_INP_WIDTH" : 3,
10-
"LOG_WGT_WIDTH" : 1,
11+
"LOG_WGT_WIDTH" : 3,
1112
"LOG_ACC_WIDTH" : 5,
1213
"LOG_OUT_WIDTH" : 3,
1314
"LOG_BATCH" : 0,
14-
"LOG_BLOCK_IN" : 5,
15-
"LOG_BLOCK_OUT" : 5,
15+
"LOG_BLOCK_IN" : 4,
16+
"LOG_BLOCK_OUT" : 4,
1617
"LOG_UOP_BUFF_SIZE" : 15,
17-
"LOG_INP_BUFF_SIZE" : 17,
18-
"LOG_WGT_BUFF_SIZE" : 17,
18+
"LOG_INP_BUFF_SIZE" : 15,
19+
"LOG_WGT_BUFF_SIZE" : 18,
1920
"LOG_ACC_BUFF_SIZE" : 17
2021
}

vta/hardware/xilinx/sim/vta_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ int main(void) {
6060
#endif // ALU_EN
6161

6262
// Run blocked GEMM test
63-
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2);
63+
// status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2);
6464
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, false, 2);
65-
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 1);
65+
// status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 1);
6666
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, false, 1);
6767

6868
// Simple GEMM unit test
69-
status |= gemm_test(4 * VTA_BATCH, 4 * VTA_BLOCK_OUT, 4 * VTA_BLOCK_IN, true);
69+
status |= gemm_test(4 * VTA_BATCH, 4 * VTA_BLOCK_OUT, 4 * VTA_BLOCK_IN, false);
7070

7171
return status;
7272
}

vta/python/vta/bitstream.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def get_bitstream_path():
2828

2929
# Derive destination path
3030
cache_dir = os.getenv("VTA_CACHE_PATH", os.path.join(os.getenv("HOME"), ".vta_cache/"))
31-
cache_dir = os.path.join(cache_dir, env.TARGET)
3231
# Create the directory if it didn't exist
3332
if not os.path.exists(cache_dir):
3433
os.makedirs(cache_dir)

vta/python/vta/environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(self, cfg):
150150
self._dev_ctx = None
151151
self._last_env = None
152152
# derive bitstream name
153-
self.BITSTREAM = "{}_{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
153+
self.BITSTREAM = "{}/bitstreams_valid/{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
154154
self.HW_VER.replace('.', '_'),
155155
self.BATCH,
156156
self.BLOCK_IN,

vta/python/vta/top/vta_conv2d.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import tvm
88
import topi
9+
import re
910

1011
from nnvm.top import registry as reg, OpPattern
1112
from nnvm.top import nn as _nn
@@ -340,7 +341,7 @@ def _get_workload(data, pad_data, kernel, output):
340341
w_str = (i_w + w_pad*2 - k_w) // (o_w - 1)
341342
return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str)
342343

343-
def schedule_packed_conv2d(outs, plan=None, skip_load_inp=False, skip_load_wgt=False,
344+
def schedule_packed_conv2d(outs, planStr=None, skip_load_inp=False, skip_load_wgt=False,
344345
skip_load_acc=False, skip_store_out=False, skip_alu=False,
345346
skip_gemm=False):
346347
""" Schedule the packed conv2d.
@@ -377,7 +378,23 @@ def _traverse(op):
377378
else:
378379
pad_data = None
379380
wrkld = _get_workload(data, pad_data, kernel, output)
380-
if plan is None:
381+
if planStr:
382+
matchObj = re.match( r'b(\d+)oc(\d+)ic(\d+)h(\d+)w(\d+)oc_t(\d+)h_t(\d+)', sched_str)
383+
b_factor = int(matchObj.group(1))
384+
oc_factor = int(matchObj.group(2))
385+
ic_factor = int(matchObj.group(3))
386+
h_factor = int(matchObj.group(4))
387+
w_factor = int(matchObj.group(5))
388+
oc_nthread = int(matchObj.group(6))
389+
h_nthread = int(matchObj.group(7))
390+
plan = Schedule(b_factor=b_factor,
391+
oc_factor=oc_factor,
392+
ic_factor=ic_factor,
393+
h_factor=h_factor,
394+
w_factor=w_factor,
395+
oc_nthread=oc_nthread,
396+
h_nthread=h_nthread)
397+
else:
381398
plan = find_schedules(wrkld, vt_only=True, best_only=True)[0]
382399
logging.info("Trying to find plan for %s", wrkld)
383400
env = get_env()

vta/src/sim/sim_driver.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ class Device {
526526

527527
template<bool use_imm, typename F>
528528
void RunALULoop(const VTAAluInsn* op, F func) {
529-
prof_->alu_counter += op->iter_out * op->iter_in * op->uop_end - op->uop_bgn;
529+
prof_->alu_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
530530
if (prof_->SkipExec()) return;
531531
for (int y = 0; y < op->iter_out; ++y) {
532532
for (int x = 0; x < op->iter_in; ++x) {

0 commit comments

Comments
 (0)