From 222962a07f8fe3ce6c1f4d5ade85fc7ee5b7a3ee Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Wed, 13 Oct 2021 15:46:48 +0200 Subject: [PATCH] [Driver] fix driver DataType access --- src/finn/transformation/fpgadataflow/make_pynq_driver.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py index bfa2fdbf9..2c3bd7ee5 100644 --- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py +++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py @@ -143,7 +143,7 @@ class MakePYNQDriver(Transformation): ) i_tensor_shape_packed = i_tensor_dummy_packed.shape # append all input tensor info to relevant lists - idt.append(str(i_tensor_dt)) + idt.append("DataType['%s']" % i_tensor_dt.name) ishape_normal.append(i_tensor_shape_normal) ishape_folded.append(i_tensor_shape_folded) ishape_packed.append(i_tensor_shape_packed) @@ -190,7 +190,7 @@ class MakePYNQDriver(Transformation): ) o_tensor_shape_packed = o_tensor_dummy_packed.shape # append all output tensor info to relevant lists - odt.append(str(o_tensor_dt)) + odt.append("DataType['%s']" % o_tensor_dt.name) oshape_normal.append(o_tensor_shape_normal) oshape_folded.append(o_tensor_shape_folded) oshape_packed.append(o_tensor_shape_packed) @@ -240,11 +240,11 @@ class MakePYNQDriver(Transformation): driver = template_driver.pynq_driver_template driver = driver.replace("$PLATFORM$", self.platform) - driver = driver.replace("$INPUT_FINN_DATATYPE$", str(idt).replace("'", "")) + driver = driver.replace("$INPUT_FINN_DATATYPE$", str(idt).replace('"', "")) driver = driver.replace("$INPUT_SHAPE_NORMAL$", str(ishape_normal)) driver = driver.replace("$INPUT_SHAPE_FOLDED$", str(ishape_folded)) driver = driver.replace("$INPUT_SHAPE_PACKED$", str(ishape_packed)) - driver = driver.replace("$OUTPUT_FINN_DATATYPE$", str(odt).replace("'", "")) + driver = driver.replace("$OUTPUT_FINN_DATATYPE$", str(odt).replace('"', "")) driver = driver.replace("$OUTPUT_SHAPE_NORMAL$", str(oshape_normal)) driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", str(oshape_folded)) driver = driver.replace("$OUTPUT_SHAPE_PACKED$", str(oshape_packed)) -- GitLab