From 3a23d26c6f4a0534c09b60bd5e012dd3ec552def Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 6 Aug 2020 01:14:03 +0200 Subject: [PATCH] [Driver] make i/o folded shape dependent on platform --- .../fpgadataflow/make_pynq_driver.py | 32 +++++++++++-------- .../transformation/fpgadataflow/templates.py | 2 +- .../test_zynqbuild_end2end_tfc_w1a1.py | 2 +- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py index 1e45a6572..a7bf9e6e6 100644 --- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py +++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py @@ -26,10 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import os -import shutil -import warnings +import shutil from finn.custom_op.registry import getCustomOp from finn.transformation import Transformation from finn.util.basic import gen_finn_dt_tensor, get_finn_root, make_build_dir @@ -48,14 +46,11 @@ class MakePYNQDriver(Transformation): value. """ - def __init__(self): + def __init__(self, platform): super().__init__() + self.platform = platform def apply(self, model): - vivado_pynq_proj = model.get_metadata_prop("vivado_pynq_proj") - if vivado_pynq_proj is None or (not os.path.isdir(vivado_pynq_proj)): - warnings.warn("No PYNQ project found, apply MakePYNQProject first.") - # create a temporary folder for the generated driver pynq_driver_dir = make_build_dir(prefix="pynq_driver_") model.set_metadata_prop("pynq_driver_dir", pynq_driver_dir) @@ -68,11 +63,21 @@ class MakePYNQDriver(Transformation): o_tensor_shape_normal = tuple(model.get_tensor_shape(o_tensor_name)) i_tensor_dt = model.get_tensor_datatype(i_tensor_name) o_tensor_dt = model.get_tensor_datatype(o_tensor_name) - # extract HLSCustomOp instances to get folded i/o shapes - first_node = getCustomOp(model.find_consumer(i_tensor_name)) - last_node = getCustomOp(model.find_producer(o_tensor_name)) - i_tensor_shape_folded = tuple(first_node.get_folded_input_shape()) - o_tensor_shape_folded = tuple(last_node.get_folded_output_shape()) + # handle folded i/o shapes due to differences in DMA engines + if self.platform == "zynq": + # extract HLSCustomOp instances to get folded i/o shapes + first_node = getCustomOp(model.find_consumer(i_tensor_name)) + last_node = getCustomOp(model.find_producer(o_tensor_name)) + i_tensor_shape_folded = tuple(first_node.get_folded_input_shape()) + o_tensor_shape_folded = tuple(last_node.get_folded_output_shape()) + else: + i_tensor_shape_folded = list(i_tensor_shape_normal) + i_tensor_shape_folded.insert(-1, 1) + i_tensor_shape_folded = tuple(i_tensor_shape_folded) + o_tensor_shape_folded = list(o_tensor_shape_normal) + o_tensor_shape_folded.insert(-1, 1) + o_tensor_shape_folded = tuple(o_tensor_shape_folded) + # generate dummy folded i/o tensors and their packed versions i_tensor_dummy_folded = gen_finn_dt_tensor(i_tensor_dt, i_tensor_shape_folded) o_tensor_dummy_folded = gen_finn_dt_tensor(o_tensor_dt, o_tensor_shape_folded) @@ -99,6 +104,7 @@ class MakePYNQDriver(Transformation): ret = ret.replace("[1,", "[%s," % batch_var_name) return ret + driver = driver.replace("$PLATFORM$", self.platform) driver = driver.replace("$INPUT_FINN_DATATYPE$", str(i_tensor_dt)) driver = driver.replace("$INPUT_SHAPE_NORMAL$", mss(i_tensor_shape_normal)) driver = driver.replace("$INPUT_SHAPE_FOLDED$", mss(i_tensor_shape_folded)) diff --git a/src/finn/transformation/fpgadataflow/templates.py b/src/finn/transformation/fpgadataflow/templates.py index d22163cbe..036286046 100644 --- a/src/finn/transformation/fpgadataflow/templates.py +++ b/src/finn/transformation/fpgadataflow/templates.py @@ -104,7 +104,7 @@ from finn.core.datatype import DataType from pynq.ps import Clocks class FINNAccelDriver(): - def __init__(self, N, bitfile, platform="zynq"): + def __init__(self, N, bitfile, platform="$PLATFORM$"): \"\"\"Instantiate the FINN accelerator driver. Gets batchsize (N) as integer and path to bitfile as string.\"\"\" self.platform = platform diff --git a/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py b/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py index f3250baa5..4a5901d32 100644 --- a/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py +++ b/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py @@ -156,7 +156,7 @@ def test_end2end_zynqbuild_tfc_w1a1_make_driver(): model = load_test_checkpoint_or_skip( build_dir + "/end2end_zynqbuild_tfc_w1a1_folded.onnx" ) - model = model.transform(MakePYNQDriver()) + model = model.transform(MakePYNQDriver(platform="zynq-iodma")) model.save(build_dir + "/end2end_zynqbuild_tfc_w1a1_pynq_driver.onnx") -- GitLab