From faaf3cef499aab3ca6316e4c4633cecb26b4dc5d Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 24 Mar 2020 23:16:38 +0000
Subject: [PATCH] [StreamingFC] more I/O shape fixes to support matrix-matrix (mat-mat) execution

---
 .../fpgadataflow/streamingfclayer_batch.py    | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index bf1f2fd7f..39386b00b 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -251,8 +251,20 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         folded_output_shape = tuple(vecs + [nf, pe])
         return folded_output_shape
 
+    def get_normal_input_shape(self):
+        mw = self.get_nodeattr("MW")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        normal_input_shape = tuple(vecs + [mw])
+        return normal_input_shape
+
+    def get_normal_output_shape(self):
+        mh = self.get_nodeattr("MH")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        normal_output_shape = tuple(vecs + [mh])
+        return normal_output_shape
+
     def get_number_output_values(self):
-        nf = self.get_folded_output_shape()[1]
+        nf = np.prod(self.get_folded_output_shape()[:-1])
         return nf
 
     def get_template_param_values(self):
@@ -466,7 +478,6 @@ class StreamingFCLayer_Batch(HLSCustomOp):
     def execute_node(self, context, graph):
         mode = self.get_nodeattr("exec_mode")
         node = self.onnx_node
-        mh = self.get_nodeattr("MH")
 
         # TODO ensure codegen dir exists
         if mode == "npysim":
@@ -522,7 +533,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                 context[node.output[0]].shape == self.get_folded_output_shape()
             ), """Output shape is not as expected"""
             # reshape output to have expected shape
-            context[node.output[0]] = context[node.output[0]].reshape(1, mh)
+            oshape = self.get_normal_output_shape()
+            context[node.output[0]] = context[node.output[0]].reshape(*oshape)
         elif mode == "rtlsim":
             prefixed_top_name = "%s_%s" % (node.name, node.name)
             # check if needed file exists
@@ -556,7 +568,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
 
                 # load and reshape output
                 output = np.load(out_npy_path)
-                output = np.asarray([output], dtype=np.float32).reshape(1, mh)
+                oshape = self.get_normal_output_shape()
+                output = np.asarray([output], dtype=np.float32).reshape(*oshape)
                 context[node.output[0]] = output
 
             else:
-- 
GitLab