From 4e3773699d6c3881774dfaf1c6d4bb1a3af691c9 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Fri, 5 Mar 2021 16:47:49 +0000
Subject: [PATCH] [fpgadataflow] Add ext weight count check

---
 src/finn/qnn-data/templates/driver/driver_base.py        | 6 ++++++
 src/finn/transformation/fpgadataflow/make_pynq_driver.py | 5 ++++-
 src/finn/transformation/fpgadataflow/template_driver.py  | 3 ++-
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/finn/qnn-data/templates/driver/driver_base.py b/src/finn/qnn-data/templates/driver/driver_base.py
index 439a1304a..01f332b3b 100644
--- a/src/finn/qnn-data/templates/driver/driver_base.py
+++ b/src/finn/qnn-data/templates/driver/driver_base.py
@@ -156,6 +156,12 @@ class FINNExampleOverlay(Overlay):
 
                 self.external_weights +=[(iwdma,weight_buf)]
 
+        if "number_of_external_weights" in self._io_shape_dict:
+            hw_ext_weights = self._io_shape_dict["number_of_external_weights"]
+            assert len(self.external_weights) == hw_ext_weights, (
+                "Number of hardware external weights and number of external " +
+                "weight tensors available do not match. \n"+
+                "Is runtime_weight_dir pointing to the correct folder?")
 
 
     def load_runtime_weights(self, flush_accel=True, verify=True):
diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
index 092c43c87..de5e180f5 100644
--- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py
+++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
@@ -123,6 +123,7 @@ class MakePYNQDriver(Transformation):
         
         os.makedirs(weights_dir)
         idma_idx = 0
+        ext_weight_dma_cnt = 0
             
         for node in model.graph.node:
             assert node.op_type == "StreamingDataflowPartition", (
@@ -134,6 +135,7 @@ class MakePYNQDriver(Transformation):
             if producer is None : # input dma?
                 idma_name = "idma" + str(idma_idx)
                 if init_tensor is not None: # input weights dma?
+                    ext_weight_dma_cnt += 1
                     w_dtype = model.get_tensor_datatype(node.input[0])
                     init_external_tensor = to_external_tensor(init_tensor,w_dtype)
                     np.save(weights_dir+"/"+ idma_name+".npy",init_external_tensor)
@@ -169,7 +171,8 @@ class MakePYNQDriver(Transformation):
         driver = driver.replace("$OUTPUT_SHAPE_NORMAL$", mss(o_tensor_shape_normal))
         driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", mss(o_tensor_shape_folded))
         driver = driver.replace("$OUTPUT_SHAPE_PACKED$", mss(o_tensor_shape_packed))
-        driver = driver.replace("$INPUT_DMA_NAME$", "'%s'" %net_input_name)
+        driver = driver.replace("$INPUT_DMA_NAME$", "'%s'" % net_input_name)
+        driver = driver.replace("$EXT_WEIGHT_NUM$", str(ext_weight_dma_cnt) )
 
         with open(driver_py, "w") as f:
             f.write(driver)
diff --git a/src/finn/transformation/fpgadataflow/template_driver.py b/src/finn/transformation/fpgadataflow/template_driver.py
index 6be318247..5265835dd 100644
--- a/src/finn/transformation/fpgadataflow/template_driver.py
+++ b/src/finn/transformation/fpgadataflow/template_driver.py
@@ -78,7 +78,8 @@ io_shape_dict = {
     "oshape_folded" : $OUTPUT_SHAPE_FOLDED$,
     "ishape_packed" : $INPUT_SHAPE_PACKED$,
     "oshape_packed" : $OUTPUT_SHAPE_PACKED$,
-    "input_dma_name" : $INPUT_DMA_NAME$
+    "input_dma_name" : $INPUT_DMA_NAME$,
+    "number_of_external_weights": $EXT_WEIGHT_NUM$
 }
 
 if __name__ == "__main__":
-- 
GitLab