From 8e5b8bac243f2734d716c09fb24c3853d2e75150 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 5 Dec 2019 23:03:16 +0000
Subject: [PATCH] [Transform] fix InferBinaryStreamingFCLayer shapes and attrs

---
 .../fpgadataflow/convert_to_hls_layers.py     | 54 ++++++++++---------
 .../test_convert_to_hls_layers.py             | 18 +++----
 2 files changed, 38 insertions(+), 34 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 60c3fa4d7..f75cf2e91 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -1,5 +1,6 @@
-import onnx.helper as oh
+from onnx import helper
 
+from finn.core.datatype import DataType
 from finn.transformation import Transformation
 
 
@@ -35,40 +36,43 @@ class InferBinaryStreamingFCLayer(Transformation):
                         # create node with no parallelization first
                         pe = 1
                         simd = 1
-                        wmem = int(mw * mh)
-                        # extract threshold shape
-                        tmem = int(mh / pe)
-                        n_thres = T.shape[1]
+                        assert mh % pe == 0
+                        assert mw % simd == 0
+                        wmem = mw * mh // (pe * simd)
+                        assert mw * mh == wmem * pe * simd
+                        nf = mh // pe
+                        tmem = nf
                         assert T.shape[0] == 1 or T.shape[0] == mh
-                        assert n_thres == 1
-                        # W is expected to be (PE, WMEM, SIMD)
-                        # transpose first to meet finn-hlslib assumptions
-                        W_new = W.transpose().reshape(pe, wmem, simd)
-                        model.set_initializer(mm_weight, W_new)
-                        # T is expected to be (NF, PE, n_thres)
-                        # TODO need to double-check the threshold shape here
-                        T_new = T.reshape(pe, tmem, n_thres)
-                        model.set_initializer(mt_thres, T_new)
-                        # reshape input and output tensors to expected shape
-                        # input is expected to be (1, mw/simd, simd)
-                        # output is expected to be (1, mh/pe, pe)
-                        in_shape = [1, int(mw / simd), simd]
-                        out_shape = [1, int(mh / pe), pe]
+                        idt = DataType.BINARY
+                        wdt = DataType.BINARY
+                        odt = model.get_tensor_datatype(mt_output)
+                        if odt.bitwidth() == 1:
+                            # covers both bipolar and binary
+                            actval = 0
+                        else:
+                            actval = odt.min()
+                        in_shape = [1, mw]
+                        out_shape = [1, mh]
                         model.set_tensor_shape(mm_input, in_shape)
                         model.set_tensor_shape(mt_output, out_shape)
                         # create and insert new StreamingFCLayer node
-                        new_node = oh.make_node(
+                        new_node = helper.make_node(
                             "StreamingFCLayer_Batch",
                             [mm_input, mm_weight, mt_thres],
                             [mt_output],
                             domain="finn",
                             backend="fpgadataflow",
-                            MH=mh,
-                            MW=mw,
-                            PE=1,
-                            SIMD=1,
-                            resDataType="Recast<XnorMul>",
                             resType="ap_resource_lut()",
+                            MW=mw,
+                            MH=mh,
+                            SIMD=simd,
+                            PE=pe,
+                            WMEM=wmem,
+                            TMEM=tmem,
+                            inputDataType=idt.name,
+                            weightDataType=wdt.name,
+                            outputDataType=odt.name,
+                            ActVal=actval,
                         )
                         graph.node.insert(node_ind, new_node)
                         # remove old nodes
diff --git a/tests/fpgadataflow/test_convert_to_hls_layers.py b/tests/fpgadataflow/test_convert_to_hls_layers.py
index 49eccd4fe..27a0dfe00 100644
--- a/tests/fpgadataflow/test_convert_to_hls_layers.py
+++ b/tests/fpgadataflow/test_convert_to_hls_layers.py
@@ -37,17 +37,17 @@ def test_convert_to_hls_layers_lfc_w1a1():
     model = model.transform(to_hls.InferBinaryStreamingFCLayer())
     fc0 = model.graph.node[2]
     assert fc0.op_type == "StreamingFCLayer_Batch"
-    assert model.get_tensor_shape(fc0.input[0]) == [1, 784, 1]
-    assert model.get_tensor_shape(fc0.input[1]) == [1, 784 * 1024, 1]
-    assert model.get_tensor_shape(fc0.input[2]) == [1, 1024, 1]
+    assert model.get_tensor_shape(fc0.input[0]) == [1, 784]
+    assert model.get_tensor_shape(fc0.input[1]) == [784, 1024]
+    assert model.get_tensor_shape(fc0.input[2]) == [1024, 1]
     fc1 = model.graph.node[3]
     assert fc1.op_type == "StreamingFCLayer_Batch"
-    assert model.get_tensor_shape(fc1.input[0]) == [1, 1024, 1]
-    assert model.get_tensor_shape(fc1.input[1]) == [1, 1024 * 1024, 1]
-    assert model.get_tensor_shape(fc1.input[2]) == [1, 1024, 1]
+    assert model.get_tensor_shape(fc1.input[0]) == [1, 1024]
+    assert model.get_tensor_shape(fc1.input[1]) == [1024, 1024]
+    assert model.get_tensor_shape(fc1.input[2]) == [1024, 1]
     fc2 = model.graph.node[4]
     assert fc2.op_type == "StreamingFCLayer_Batch"
-    assert model.get_tensor_shape(fc2.input[0]) == [1, 1024, 1]
-    assert model.get_tensor_shape(fc2.input[1]) == [1, 1024 * 1024, 1]
-    assert model.get_tensor_shape(fc2.input[2]) == [1, 1024, 1]
+    assert model.get_tensor_shape(fc2.input[0]) == [1, 1024]
+    assert model.get_tensor_shape(fc2.input[1]) == [1024, 1024]
+    assert model.get_tensor_shape(fc2.input[2]) == [1024, 1]
     os.remove(export_onnx_path)
-- 
GitLab