Skip to content
Snippets Groups Projects
Commit a2f2d752 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[PYNQ] more flexible driver that can do N samples instead of 1

parent 49b8f79f
No related branches found
No related tags found
No related merge requests found
......@@ -87,14 +87,25 @@ class MakePYNQDriver(Transformation):
# fill in the driver template
driver_py = pynq_driver_dir + "/driver.py"
driver = templates.pynq_driver_template
def mss(x, batch_var_name="N"):
# "make shape string"
# for a shape like (1, ...) emit a string (N, ...)
# where N is the default value for batch_var_name
# this lets the driver work with a batch of samples at once
ret = str(x)
ret = ret.replace("(1,", "(%s," % batch_var_name)
ret = ret.replace("[1,", "[%s," % batch_var_name)
return ret
driver = driver.replace("$INPUT_FINN_DATATYPE$", str(i_tensor_dt))
driver = driver.replace("$INPUT_SHAPE_NORMAL$", str(i_tensor_shape_normal))
driver = driver.replace("$INPUT_SHAPE_FOLDED$", str(i_tensor_shape_folded))
driver = driver.replace("$INPUT_SHAPE_PACKED$", str(i_tensor_shape_packed))
driver = driver.replace("$INPUT_SHAPE_NORMAL$", mss(i_tensor_shape_normal))
driver = driver.replace("$INPUT_SHAPE_FOLDED$", mss(i_tensor_shape_folded))
driver = driver.replace("$INPUT_SHAPE_PACKED$", mss(i_tensor_shape_packed))
driver = driver.replace("$OUTPUT_FINN_DATATYPE$", str(o_tensor_dt))
driver = driver.replace("$OUTPUT_SHAPE_NORMAL$", str(o_tensor_shape_normal))
driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", str(o_tensor_shape_folded))
driver = driver.replace("$OUTPUT_SHAPE_PACKED$", str(o_tensor_shape_packed))
driver = driver.replace("$OUTPUT_SHAPE_NORMAL$", mss(o_tensor_shape_normal))
driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", mss(o_tensor_shape_folded))
driver = driver.replace("$OUTPUT_SHAPE_PACKED$", mss(o_tensor_shape_packed))
with open(driver_py, "w") as f:
f.write(driver)
......
......@@ -98,6 +98,13 @@ from finn.core.datatype import DataType
bitfile_path = "resizer.bit"
ol = Overlay(bitfile_path)
dma=ol.axi_dma_0
ctrl_regs=ol.resize_accel_0
# AXI lite register offset for number of iterations
# used by TLastMarker to signal end of transmission for AXI CDMA
REG_OFFSET_NUM_ITERS = 0x10
# number of samples for inference
N = 1
# declare input/output types and shapes for the accelerator
# input FINN DataType
......@@ -113,6 +120,9 @@ oshape_normal = $OUTPUT_SHAPE_NORMAL$
oshape_folded = $OUTPUT_SHAPE_FOLDED$
oshape_packed = $OUTPUT_SHAPE_PACKED$
# set up TLastMarker with correct num. samples
ctrl_regs.write(REG_OFFSET_NUM_ITERS, N)
# load desired input .npy file
ibuf_normal = np.load("input.npy")
# ensure that shape is as expected
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment