From 222962a07f8fe3ce6c1f4d5ade85fc7ee5b7a3ee Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Wed, 13 Oct 2021 15:46:48 +0200
Subject: [PATCH] [Driver] fix driver DataType access

---
 src/finn/transformation/fpgadataflow/make_pynq_driver.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
index bfa2fdbf9..2c3bd7ee5 100644
--- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py
+++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
@@ -143,7 +143,7 @@ class MakePYNQDriver(Transformation):
             )
             i_tensor_shape_packed = i_tensor_dummy_packed.shape
             # append all input tensor info to relevant lists
-            idt.append(str(i_tensor_dt))
+            idt.append("DataType['%s']" % i_tensor_dt.name)
             ishape_normal.append(i_tensor_shape_normal)
             ishape_folded.append(i_tensor_shape_folded)
             ishape_packed.append(i_tensor_shape_packed)
@@ -190,7 +190,7 @@ class MakePYNQDriver(Transformation):
             )
             o_tensor_shape_packed = o_tensor_dummy_packed.shape
             # append all output tensor info to relevant lists
-            odt.append(str(o_tensor_dt))
+            odt.append("DataType['%s']" % o_tensor_dt.name)
             oshape_normal.append(o_tensor_shape_normal)
             oshape_folded.append(o_tensor_shape_folded)
             oshape_packed.append(o_tensor_shape_packed)
@@ -240,11 +240,11 @@ class MakePYNQDriver(Transformation):
         driver = template_driver.pynq_driver_template
 
         driver = driver.replace("$PLATFORM$", self.platform)
-        driver = driver.replace("$INPUT_FINN_DATATYPE$", str(idt).replace("'", ""))
+        driver = driver.replace("$INPUT_FINN_DATATYPE$", str(idt).replace('"', ""))
         driver = driver.replace("$INPUT_SHAPE_NORMAL$", str(ishape_normal))
         driver = driver.replace("$INPUT_SHAPE_FOLDED$", str(ishape_folded))
         driver = driver.replace("$INPUT_SHAPE_PACKED$", str(ishape_packed))
-        driver = driver.replace("$OUTPUT_FINN_DATATYPE$", str(odt).replace("'", ""))
+        driver = driver.replace("$OUTPUT_FINN_DATATYPE$", str(odt).replace('"', ""))
         driver = driver.replace("$OUTPUT_SHAPE_NORMAL$", str(oshape_normal))
         driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", str(oshape_folded))
         driver = driver.replace("$OUTPUT_SHAPE_PACKED$", str(oshape_packed))
-- 
GitLab