From 915ec0d4a5991cdf08f354474fae482721dcae54 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Thu, 7 Nov 2019 11:48:32 -0500
Subject: [PATCH] [Code generation] Added some of the missing layer parameters
 (WMEM, TMEM and API) to get_layer_parameters()

---
 src/finn/backend/fpgadataflow/code_gen.py | 27 ++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/finn/backend/fpgadataflow/code_gen.py b/src/finn/backend/fpgadataflow/code_gen.py
index 45c1dcabc..2c742075b 100644
--- a/src/finn/backend/fpgadataflow/code_gen.py
+++ b/src/finn/backend/fpgadataflow/code_gen.py
@@ -1,4 +1,4 @@
-def get_layer_attributes(node):
+def get_layer_parameters(model, node):
     # Layer attributes
     num_attr = len(node.attribute)
     for k in range(num_attr):
@@ -14,6 +14,14 @@ def get_layer_attributes(node):
             L_resDataType = node.attribute[k].s
         if node.attribute[k].name == "resType":
             L_resType = node.attribute[k].s
+
+    # get other parameters
+    weights_shape = model.get_tensor_shape(node.input[1])
+    thresholds_shape = model.get_tensor_shape(node.input[2])
+    L_WMEM = weights_shape[2]
+    L_TMEM = thresholds_shape[0]
+    L_API = thresholds_shape[2]
+
     return [
         L_PE,
         L_SIMD,
@@ -21,6 +29,9 @@ def get_layer_attributes(node):
         L_MW,
         L_resDataType.decode("utf-8"),
         L_resType.decode("utf-8"),
+        L_WMEM,
+        L_TMEM,
+        L_API,
     ]
 
 
@@ -87,8 +98,18 @@ def computation_cmds(model, code_gen_dict):
             weights = node.input[1]
             thresholds = node.input[2]
             outp = node.output[0]
-            # get layer attributes
-            [PE, SIMD, MH, MW, resDataType, resType] = get_layer_attributes(node)
+            # get layer parameters
+            [
+                PE,
+                SIMD,
+                MH,
+                MW,
+                resDataType,
+                resType,
+                WMEM,
+                TMEM,
+                API,
+            ] = get_layer_parameters(model, node)
 
             code_gen_dict["compute"].append(
                 "{}<L{}_MW, L{}_MH, L{}_SIMD, L{}_PE, {}> "
-- 
GitLab