From 1f2488aecbf9f7ae0169943327639879fc2987b5 Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Wed, 23 Jun 2021 12:23:45 +0100
Subject: [PATCH] Add SIMD requirement check to StreamingFCLayer_Batch
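
defines() now asserts SIMD > MW / 1024 when called in "ipgen" mode,
since HLS synthesis of StreamingFCLayer_Batch requires this minimum
SIMD value. The remaining hunks only reflow existing conditions and
assertions and do not change behavior.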

---
 .../fpgadataflow/streamingfclayer_batch.py    | 41 +++++++++++++------
 1 file changed, 29 insertions(+), 12 deletions(-)
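
Reviewer note (not part of the commit, ignored by git am): the new
ipgen-mode check enforces the strict bound SIMD > MW / 1024. Below is a
minimal sketch of how that lower bound behaves, using a hypothetical
helper min_simd() that is not part of FINN, purely for illustration:

    import math

    def min_simd(mw):
        # smallest SIMD satisfying the strict inequality SIMD > mw / 1024,
        # i.e. the bound asserted by defines() in ipgen mode
        return math.floor(mw / 1024) + 1

    assert min_simd(600) == 1   # MW < 1024 already passes with SIMD=1
    assert min_simd(1024) == 2  # at the boundary SIMD must strictly exceed MW/1024
    assert min_simd(4096) == 5  # e.g. SIMD=4 would trip the new assertion

Any other constraints FINN places on SIMD are outside the scope of this
patch; the sketch covers only the new lower bound.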

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 48b40f105..1e4dc0b18 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -238,9 +238,10 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         mem_width = Q * W * P
         mmode = self.get_nodeattr("mem_mode")
         mstyle = self.get_nodeattr("ram_style")
-        if (mmode == "decoupled" and mstyle != "ultra") or (
-            mmode == "const" and self.calc_wmem() <= 128) or (
-            mmode == "external"
+        if (
+            (mmode == "decoupled" and mstyle != "ultra")
+            or (mmode == "const" and self.calc_wmem() <= 128)
+            or (mmode == "external")
         ):
             return 0
         width_multiplier = math.ceil(mem_width / 72)
@@ -266,9 +267,10 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         mem_width = Q * W * P
         mmode = self.get_nodeattr("mem_mode")
         mstyle = self.get_nodeattr("ram_style")
-        if (mmode == "decoupled" and mstyle in ["distributed", "ultra"]) or (
-            mmode == "const" and self.calc_wmem() <= 128) or (
-            mmode == "external"
+        if (
+            (mmode == "decoupled" and mstyle in ["distributed", "ultra"])
+            or (mmode == "const" and self.calc_wmem() <= 128)
+            or (mmode == "external")
         ):
             return 0
         # assuming SDP mode RAMB18s (see UG573 Table 1-10)
@@ -604,9 +606,11 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                     tdt = DataType.get_smallest_possible(0 - tdt_max)
             else:
                 tdt = DataType.get_smallest_possible(tdt_max)
-            assert np.vectorize(tdt.allowed)(threshold_tensor).all(), (
-                "Thresholds in %s can't be expressed with type %s"
-                % (self.onnx_node.name, str(tdt))
+            assert np.vectorize(tdt.allowed)(
+                threshold_tensor
+            ).all(), "Thresholds in %s can't be expressed with type %s" % (
+                self.onnx_node.name,
+                str(tdt),
             )
             self.set_nodeattr("accDataType", tdt.name)
         else:
@@ -843,9 +847,11 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                 # get computed threshold datatype from attribute
                 tdt = DataType[self.get_nodeattr("accDataType")]
 
-                assert np.vectorize(tdt.allowed)(threshold_tensor).all(), (
-                    "Thresholds in %s can't be expressed with type %s"
-                    % (self.onnx_node.name, str(tdt))
+                assert np.vectorize(tdt.allowed)(
+                    threshold_tensor
+                ).all(), "Thresholds in %s can't be expressed with type %s" % (
+                    self.onnx_node.name,
+                    str(tdt),
                 )
                 thresholds_hls_code = numpy_to_hls_code(
                     threshold_tensor, tdt, "thresholds", False, True
@@ -1005,6 +1011,17 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
 
     def defines(self, var):
+        # ipgen mode only: check that the SIMD parameter meets the minimum requirement
+        if var == "ipgen":
+            SIMD = self.get_nodeattr("SIMD")
+            MW = self.get_nodeattr("MW")
+            condition = SIMD > (MW / 1024)
+            msg = (
+                f"HLS synthesis of StreamingFCLayer_Batch requires "
+                f"SIMD > MW / 1024, which is not fulfilled by SIMD={SIMD} "
+                f"and MW={MW} for node {self.onnx_node.name}."
+            )
+            assert condition, msg
         mem_mode = self.get_nodeattr("mem_mode")
         numInputVectors = list(self.get_nodeattr("numInputVectors"))
         numReps = np.prod(numInputVectors)
-- 
GitLab