From 83862bff89c19a207cfb9767c7d0d9a69ae674b1 Mon Sep 17 00:00:00 2001
From: Felix Jentzsch <fepaje@mail.upb.de>
Date: Tue, 11 May 2021 09:28:03 +0200
Subject: [PATCH] Support implicit flatten via reshape

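A Reshape node with target shape (1, -1) flattens a feature map just
like Flatten with axis=1, so treat such "implicit" flattens the same as
Flatten nodes in RemoveCNVtoFCFlatten and
AbsorbTransposeIntoMultiThreshold: both now match any Flatten/Reshape
whose output is rank 2 with the batch dimension preserved. Also drop
leftover debug prints and extend the conv->fc transition test to cover
both node types, rectangular kernels/inputs and depthwise convolutions.

A minimal numpy sketch (shapes illustrative) of the equivalence and of
the detection condition used by both transformations:

    import numpy as np

    x = np.random.rand(1, 6, 6, 4)  # batch-1 NHWC feature map
    # Reshape to (1, -1) flattens exactly like Flatten with axis=1
    assert np.array_equal(x.reshape(1, -1), x.reshape(x.shape[0], -1))

    ishape = x.shape
    oshape = x.reshape(1, -1).shape
    # rank-2 output with unchanged batch dim marks a flattening node
    assert len(oshape) == 2 and ishape[0] == oshape[0]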
---
 .../fpgadataflow/insert_fifo.py               |   6 -
 src/finn/transformation/move_reshape.py       |  73 ++++---
 src/finn/transformation/streamline/absorb.py  |  60 +++---
 .../test_convert_to_hls_conv_fc_transition.py | 204 +++++++-----------
 4 files changed, 143 insertions(+), 200 deletions(-)
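
When the flatten sits behind an NHWC->NCHW Transpose, RemoveCNVtoFCFlatten
absorbs the Transpose into the FC weight matrix instead of just dropping
nodes. A small numpy check of that row re-layout (the dimensions b, h, w,
c and mh below are illustrative; mw = c*h*w as in the transform):

    import numpy as np

    b, h, w, c, mh = 1, 3, 3, 4, 8
    x_nhwc = np.random.rand(b, h, w, c)
    x_nchw = x_nhwc.transpose(0, 3, 1, 2)
    W = np.random.rand(c * h * w, mh)  # weight rows in NCHW (c, h, w) order
    # reorder rows into NHWC (h, w, c) order, mirroring the transform
    W_new = W.reshape(c, h, w, mh).transpose(1, 2, 0, 3).reshape(c * h * w, mh)
    assert np.allclose(x_nchw.reshape(b, -1) @ W, x_nhwc.reshape(b, -1) @ W_new)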

diff --git a/src/finn/transformation/fpgadataflow/insert_fifo.py b/src/finn/transformation/fpgadataflow/insert_fifo.py
index f51e6212b..1ce936cd7 100644
--- a/src/finn/transformation/fpgadataflow/insert_fifo.py
+++ b/src/finn/transformation/fpgadataflow/insert_fifo.py
@@ -33,12 +33,6 @@ def _suitable_folded_shapes(ishape, oshape):
     matching_size = np.prod(ishape) == np.prod(oshape)
     return matching_stream_width and matching_size
 
-    # i_dummy = np.random.rand(*ishape)
-    # o_dummy = np.random.rand(*oshape)
-    # ishape_canonical = np.squeeze(i_dummy).shape
-    # oshape_canonical = np.squeeze(o_dummy).shape
-    # return ishape_canonical == oshape_canonical
-
 
 class InsertFIFO(Transformation):
     """Inserting FIFOs in the beginning and end of the graph as well as
diff --git a/src/finn/transformation/move_reshape.py b/src/finn/transformation/move_reshape.py
index 87405590e..d723a530b 100644
--- a/src/finn/transformation/move_reshape.py
+++ b/src/finn/transformation/move_reshape.py
@@ -26,28 +26,28 @@ class RemoveCNVtoFCFlatten(Transformation):
         graph = model.graph
         graph_modified = False
         for n in graph.node:
-            if n.op_type == "Flatten":  # re-add reshape
-                # shape = model.get_initializer(n.input[1])
-                # if (shape == [1, -1]).all():
-                producer = model.find_producer(n.input[0])
-                if _is_fpgadataflow_node(producer) is True:
-                    consumer = model.find_consumer(n.output[0])
-                    if _is_fpgadataflow_node(consumer) is True:
-                        graph_modified = True
-                        consumer.input[0] = n.input[0]
-                        graph.node.remove(n)
-                elif producer.op_type == "Transpose":
-                    transp_node = producer
-
-                    # check if transpose converts NHWC to NCHW
-                    perms = list(get_by_name(transp_node.attribute, "perm").ints)
-                    if perms == [0, 3, 1, 2]:
-
-                        producer = model.find_producer(transp_node.input[0])
-
-                        if _is_fpgadataflow_node(producer) is True:
-                            consumer = model.find_consumer(n.output[0])
-                            if _is_fpgadataflow_node(consumer) is True:
+            # also support implicit flatten via reshape, e.g. reshape(1,-1)
+            if n.op_type == "Flatten" or n.op_type == "Reshape":
+                ishape = model.get_tensor_shape(n.input[0])
+                oshape = model.get_tensor_shape(n.output[0])
+                if len(oshape) == 2 and ishape[0] == oshape[0]:
+                    producer = model.find_producer(n.input[0])
+                    if _is_fpgadataflow_node(producer) is True:
+                        # standalone flatten, remove
+                        consumer = model.find_consumer(n.output[0])
+                        if _is_fpgadataflow_node(consumer) is True:
+                            graph_modified = True
+                            consumer.input[0] = n.input[0]
+                            graph.node.remove(n)
+                    elif producer.op_type == "Transpose":
+                        # transpose + flatten, absorb into following node
+                        transp_node = producer
+                        # check if transpose converts NHWC to NCHW
+                        perms = list(get_by_name(transp_node.attribute, "perm").ints)
+                        if perms == [0, 3, 1, 2]:
+                            producer = model.find_producer(transp_node.input[0])
+                            if _is_fpgadataflow_node(producer) is True:
+                                consumer = model.find_consumer(n.output[0])
                                 if consumer.op_type == "StreamingFCLayer_Batch":
                                     fc_inst = getCustomOp(consumer)
                                     mw = fc_inst.get_nodeattr("MW")
@@ -55,35 +55,34 @@ class RemoveCNVtoFCFlatten(Transformation):
                                     (b, h, w, c) = model.get_tensor_shape(
                                         transp_node.input[0]
                                     )
-                                    # absorb transpose into weight matrix, allowing FC layer to operate on the NHWC input
+                                    # absorb transpose into weight matrix,
+                                    # allowing FC layer to operate on the NHWC input
                                     W = model.get_initializer(consumer.input[1])
                                     assert (
                                         W is not None
                                     ), "Initializer for matmul weights is not set."
-                                    print("fc weights before")
-                                    print(W.shape)
-                                    print(W)
-
+                                    # print("fc weights before")
+                                    # print(W.shape)
+                                    # print(W)
                                     W_new = W.reshape(c, h, w, mh)
                                     W_new = W_new.transpose((1, 2, 0, 3))
                                     W_new = W_new.reshape(mw, mh)
-
-                                    print("fc weights after")
-                                    print(W_new.shape)
-                                    print(W_new)
-
                                     model.set_initializer(consumer.input[1], W_new)
-
+                                    # print("fc weights after")
+                                    # print(W_new.shape)
+                                    # print(W_new)
                                     # remove transpose & flatten nodes
-                                    graph_modified = True
                                     consumer.input[0] = transp_node.input[0]
                                     graph.node.remove(n)
                                     graph.node.remove(transp_node)
+                                    graph_modified = True
                                 else:
                                     warnings.warn(
-                                        "Could not absorb transpose into node behind flatten layer"
+                                        "Could not absorb transpose->flatten into subsequent node"
                                     )
-                    else:
-                        warnings.warn("Unsupported transpose node before flatten layer")
+                        else:
+                            warnings.warn(
+                                "Unsupported transpose node before flatten layer"
+                            )
 
         return (model, graph_modified)
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 979a69163..e8082eb82 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -309,7 +309,8 @@ class Absorb1BitMulIntoConv(Transformation):
 
 class AbsorbTransposeIntoMultiThreshold(Transformation):
     """Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold)
-    with NHWC mode."""
+    with NHWC mode. For (NHWCTranspose -> MultiThreshold -> Flatten/Reshape),
+    move the Transpose past MultiThreshold to prepare for RemoveCNVtoFCFlatten()."""
 
     def apply(self, model):
         graph = model.graph
@@ -338,35 +339,36 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
                                 graph.node.remove(n)
                                 graph.node.remove(final_t_cand)
                                 graph_modified = True
-                        elif final_t_cand.op_type == "Flatten":  # TODO: re-add reshape
-                            # oshape = model.get_tensor_shape(final_t_cand.output[0])
-                            # if len(oshape) == 2:
-                            # transition to FC part, can still use NHWC
-                            mt = getCustomOp(mt_cand)
-                            mt.set_nodeattr("data_layout", "NHWC")
-                            # get rid of first tranpose node
-                            mt_cand.input[0] = n.input[0]
-                            # fix output shape for MultiThreshold
-                            mt_ishape = model.get_tensor_shape(mt_cand.input[0])
-                            (b, h, w, c) = mt_ishape
-                            # assert (
-                            #    h == 1 and w == 1
-                            # ), """Untested spatial dim
-                            # in conv->fc transition, proceed with caution!"""
-                            model.set_tensor_shape(mt_cand.output[0], mt_ishape)
-
-                            graph.node.remove(n)
-                            transpose_output = model.make_new_valueinfo_name()
-                            new_transpose = oh.make_node(
-                                "Transpose",
-                                [mt_cand.output[0]],
-                                [transpose_output],
-                                perm=[0, 3, 1, 2],
-                            )
-                            graph.node.insert(node_ind + 1, new_transpose)
-                            final_t_cand.input[0] = transpose_output
+                        # also support implicit flatten via reshape, e.g. reshape(1,-1)
+                        elif (
+                            final_t_cand.op_type == "Flatten"
+                            or final_t_cand.op_type == "Reshape"
+                        ):
+                            ishape = model.get_tensor_shape(final_t_cand.input[0])
+                            oshape = model.get_tensor_shape(final_t_cand.output[0])
+                            if len(oshape) == 2 and ishape[0] == oshape[0]:
+                                # transition to FC part, can still use NHWC
+                                mt = getCustomOp(mt_cand)
+                                mt.set_nodeattr("data_layout", "NHWC")
+                                # get rid of first transpose node
+                                mt_cand.input[0] = n.input[0]
+                                graph.node.remove(n)
+                                # fix output shape for MultiThreshold
+                                mt_ishape = model.get_tensor_shape(mt_cand.input[0])
+                                (b, h, w, c) = mt_ishape
+                                model.set_tensor_shape(mt_cand.output[0], mt_ishape)
+                                # re-insert Transpose behind MultiThreshold
+                                transpose_output = model.make_new_valueinfo_name()
+                                new_transpose = oh.make_node(
+                                    "Transpose",
+                                    [mt_cand.output[0]],
+                                    [transpose_output],
+                                    perm=[0, 3, 1, 2],
+                                )
+                                graph.node.insert(node_ind + 1, new_transpose)
+                                final_t_cand.input[0] = transpose_output
 
-                            graph_modified = True
+                                graph_modified = True
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
diff --git a/tests/fpgadataflow/test_convert_to_hls_conv_fc_transition.py b/tests/fpgadataflow/test_convert_to_hls_conv_fc_transition.py
index f0b9758d9..5d4d5a2a3 100755
--- a/tests/fpgadataflow/test_convert_to_hls_conv_fc_transition.py
+++ b/tests/fpgadataflow/test_convert_to_hls_conv_fc_transition.py
@@ -81,47 +81,56 @@ def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
     return thres
 
 
-# conv_config  kernel_size,stride, pad
-
-
-# @pytest.mark.parametrize(
-#    "conv_config", [(1, 2, 0), (1, 3, 0), (3, 2, 1), (3, 1, 0), (3, 1, 1), (5, 2, 1)]
-# )
-# @pytest.mark.parametrize("depthwise", [False, True])
-# @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
-@pytest.mark.parametrize("conv_config", [(3, 1, 1)])
-@pytest.mark.parametrize("depthwise", [False])
-@pytest.mark.parametrize("exec_mode", ["cppsim"])
+# conv_config: input_shape, kernel_shape, stride, pad
+@pytest.mark.parametrize(
+    "conv_config",
+    [
+        ((6, 6), (3, 3), (1, 1), (1, 1)),
+        ((12, 1), (3, 1), (1, 1), (1, 0)),
+        ((1, 15), (1, 5), (1, 1), (0, 2)),
+    ],
+)
+@pytest.mark.parametrize("depthwise", [False, True])
+@pytest.mark.parametrize("use_reshape", [False, True])
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
 @pytest.mark.slow
 @pytest.mark.vivado
-def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
-    kernel_size, stride, pad = conv_config
+def test_convert_to_hls_conv_fc_transition(
+    conv_config, depthwise, use_reshape, exec_mode
+):
     np.random.seed(0)
     idt = DataType.UINT4
     odt = DataType.UINT4
 
-    in_feature_dim = 2
+    input_shape, kernel_shape, stride, pad = conv_config
+    kernel_size_h, kernel_size_w = kernel_shape
+    input_size_h, input_size_w = input_shape
+    stride_h, stride_w = stride
+    pad_h, pad_w = pad
+
     in_chn = 4
-    fc_filters = 8
+    fc_filters = 16
 
     if depthwise is True:
         group = out_chn = in_chn
-        conv_param_shape = [out_chn, 1, kernel_size, kernel_size]
+        conv_param_shape = [out_chn, 1, kernel_size_h, kernel_size_w]
     else:
         group = 1
-        out_chn = 4
-        conv_param_shape = [out_chn, in_chn, kernel_size, kernel_size]
+        out_chn = 8
+        conv_param_shape = [out_chn, in_chn, kernel_size_h, kernel_size_w]
 
-    total_pad = 2 * pad
-    out_feature_dim = compute_conv_output_dim(
-        in_feature_dim, kernel_size, stride, total_pad
+    output_size_h = compute_conv_output_dim(
+        input_size_h, kernel_size_h, stride_h, 2 * pad_h
+    )
+    output_size_w = compute_conv_output_dim(
+        input_size_w, kernel_size_w, stride_w, 2 * pad_w
     )
 
-    input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
-    conv_output_shape = [1, out_chn, out_feature_dim, out_feature_dim]
+    input_shape = [1, in_chn, input_size_h, input_size_w]
+    conv_output_shape = [1, out_chn, output_size_h, output_size_w]
     output_shape = [1, fc_filters]
 
-    fc_param_shape = [out_chn * out_feature_dim * out_feature_dim, fc_filters]
+    fc_param_shape = [out_chn * output_size_h * output_size_w, fc_filters]
 
     conv_weight_dt = DataType.INT4
     fc_weight_dt = DataType.INT4
@@ -129,9 +138,9 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
     conv_config = {}
     conv_config["dilations"] = [1, 1]
     conv_config["group"] = group
-    conv_config["kernel_shape"] = [kernel_size, kernel_size]
-    conv_config["pads"] = [pad, pad, pad, pad]
-    conv_config["strides"] = [stride, stride]
+    conv_config["kernel_shape"] = [kernel_size_h, kernel_size_w]
+    conv_config["pads"] = [pad_h, pad_w, pad_h, pad_w]
+    conv_config["strides"] = [stride_h, stride_w]
 
     global_in = helper.make_tensor_value_info(
         "global_in", TensorProto.FLOAT, input_shape
@@ -150,8 +159,18 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
         helper.make_tensor_value_info(
             "thres2_param", TensorProto.FLOAT, (fc_filters, 15)
         ),
+        helper.make_tensor_value_info("reshape_shape", TensorProto.INT64, [2]),
     ]
 
+    if use_reshape:
+        flatten_node = helper.make_node(
+            "Reshape", ["thres1_out", "reshape_shape"], ["flatten_out"]
+        )
+    else:
+        flatten_node = helper.make_node(
+            "Flatten", ["thres1_out"], ["flatten_out"], axis=1
+        )
+
     modelproto = helper.make_model(
         helper.make_graph(
             name="test",
@@ -168,10 +187,8 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
                     ["thres1_out"],
                     domain="finn.custom_op.general",
                     out_dtype="UINT4",
-                    # out_bias=-7,
-                    # out_scale=1.0
                 ),
-                helper.make_node("Flatten", ["thres1_out"], ["flatten_out"], axis=1),
+                flatten_node,
                 helper.make_node(
                     "MatMul", ["flatten_out", "matmul_param"], ["matmul_out"]
                 ),
@@ -181,8 +198,6 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
                     ["global_out"],
                     domain="finn.custom_op.general",
                     out_dtype="UINT4",
-                    # out_bias=-7,
-                    # out_scale=1.0
                 ),
             ],
         )
@@ -196,9 +211,7 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
     model.set_tensor_datatype("matmul_param", fc_weight_dt)
     model.set_tensor_datatype("thres1_param", DataType.INT32)
     model.set_tensor_datatype("thres2_param", DataType.INT32)
-    model.set_tensor_datatype(
-        "flatten_out", DataType.UINT4
-    )  # TODO: not inferred automatically (FLOAT32)
+
     model.set_initializer(
         "conv_param", gen_finn_dt_tensor(conv_weight_dt, conv_param_shape)
     )
@@ -211,6 +224,7 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
     model.set_initializer(
         "matmul_param", gen_finn_dt_tensor(fc_weight_dt, fc_param_shape)
     )
+    model.set_initializer("reshape_shape", np.array([1, -1], dtype=np.int64))
 
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
@@ -218,81 +232,30 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
 
     model.save("testmodel_in.onnx")
 
-    x = gen_finn_dt_tensor(idt, input_shape)
-    inp_dict = {model.graph.input[0].name: x}
-    output = oxe.execute_onnx(model, inp_dict)
-    print(output)
-
-    # streamlining step
-    model = model.transform(MoveScalarLinearPastInvariants())
-    model = model.transform(Streamline())
-    model = model.transform(LowerConvsToMatMul())
-    model = model.transform(MakeMaxPoolNHWC())
-    model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
-    model = model.transform(Streamline())
-
-    model = model.transform(InferDataLayouts())
-    model = model.transform(RemoveUnusedTensors())
-
-    model.save("testmodel_streamlined.onnx")
-
-    output = oxe.execute_onnx(model, inp_dict)
-    print(output)
+    # streamlining
+    new_model = model.transform(MoveScalarLinearPastInvariants())
+    new_model = new_model.transform(Streamline())
+    new_model = new_model.transform(LowerConvsToMatMul())
+    new_model = new_model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
+    new_model = new_model.transform(Streamline())
 
-    # convert_to_hls step
-    model = model.transform(to_hls.InferQuantizedStreamingFCLayer())
-    model = model.transform(to_hls.InferThresholdingLayer())
-    model = model.transform(to_hls.InferConvInpGen())
-    model = model.transform(to_hls.InferStreamingMaxPool())
-    model = model.transform(RemoveCNVtoFCFlatten())
-    model = model.transform(absorb.AbsorbConsecutiveTransposes())
+    new_model = new_model.transform(InferDataLayouts())
+    new_model = new_model.transform(RemoveUnusedTensors())
 
-    model = model.transform(GiveUniqueNodeNames())
-    model = model.transform(InferDataLayouts())
-
-    if exec_mode == "cppsim":
-        model = model.transform(PrepareCppSim())
-        model = model.transform(CompileCppSim())
-        model = model.transform(SetExecMode("cppsim"))
-    elif exec_mode == "rtlsim":
-        model = model.transform(SetExecMode("rtlsim"))
-        model = model.transform(GiveUniqueNodeNames())
-        model = model.transform(PrepareIP("xc7z020clg400-1", 5))
-        model = model.transform(HLSSynthIP())
-        model = model.transform(PrepareRTLSim())
-    else:
-        raise Exception("Unknown exec_mode")
-
-    model.save("testmodel_hls.onnx")
-
-    output = oxe.execute_onnx(model, inp_dict)
-    print(output)
-
-    model_orig = ModelWrapper("testmodel_in.onnx")
-    model_hls = ModelWrapper("testmodel_hls.onnx")
+    new_model.save("testmodel_streamlined.onnx")
 
-    assert oxe.compare_execution(model_orig, model_hls, inp_dict)
-
-
-"""
-    new_model = model.transform(LowerConvsToMatMul())
-    new_model = new_model.transform(to_hls.InferConvInpGen())
+    # convert_to_hls
     if depthwise is True:
         new_model = new_model.transform(to_hls.InferVVAU())
-    else:
-        new_model = new_model.transform(to_hls.InferQuantizedStreamingFCLayer())
-        fc_node = new_model.get_nodes_by_op_type("StreamingFCLayer_Batch")[0]
-        fc_inst = getCustomOp(fc_node)
-        mw = fc_inst.get_nodeattr("MW")
-        mh = fc_inst.get_nodeattr("MH")
-        pe_cands = list(filter(lambda x: mh % x == 0, range(2, mh + 1)))
-        simd_cands = list(filter(lambda x: mw % x == 0, range(2, mw + 1)))
-        fc_inst.set_nodeattr("PE", pe_cands[0])
-        fc_inst.set_nodeattr("SIMD", simd_cands[0])
+    new_model = new_model.transform(to_hls.InferQuantizedStreamingFCLayer())
+    new_model = new_model.transform(to_hls.InferThresholdingLayer())
+    new_model = new_model.transform(to_hls.InferConvInpGen())
+    new_model = new_model.transform(to_hls.InferStreamingMaxPool())
+    new_model = new_model.transform(RemoveCNVtoFCFlatten())
+    new_model = new_model.transform(absorb.AbsorbConsecutiveTransposes())
 
     new_model = new_model.transform(GiveUniqueNodeNames())
-    new_model = new_model.transform(InferShapes())
-    new_model = new_model.transform(InferDataTypes())
+    new_model = new_model.transform(InferDataLayouts())
 
     if exec_mode == "cppsim":
         new_model = new_model.transform(PrepareCppSim())
@@ -307,31 +270,16 @@ def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, exec_mode):
     else:
         raise Exception("Unknown exec_mode")
 
+    new_model.save("testmodel_hls.onnx")
+
+    # check for correct execution
     x = gen_finn_dt_tensor(idt, input_shape)
     inp_dict = {model.graph.input[0].name: x}
     assert oxe.compare_execution(model, new_model, inp_dict)
-    if kernel_size == 1 and stride > 1 and pad == 0:
-        assert new_model.graph.node[1].op_type == "DownSampler"
-        if exec_mode == "rtlsim":
-            node = new_model.get_nodes_by_op_type("DownSampler")[0]
-            inst = getCustomOp(node)
-            cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
-            exp_cycles_dict = new_model.analysis(exp_cycles_per_layer)
-            exp_cycles = exp_cycles_dict[node.name]
-            assert np.isclose(exp_cycles, cycles_rtlsim, atol=11)
-            assert exp_cycles != 0
-
-    if pad == 1:
-        padding_node = new_model.get_nodes_by_op_type("FMPadding_Batch")[0]
-        padding_inst = getCustomOp(padding_node)
-        assert padding_inst.get_nodeattr("SIMD") == in_chn
-
-    if depthwise is True and exec_mode == "rtlsim":
-        node = new_model.get_nodes_by_op_type("Vector_Vector_Activate_Batch")[0]
-        inst = getCustomOp(node)
-        cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
-        exp_cycles_dict = new_model.analysis(exp_cycles_per_layer)
-        exp_cycles = exp_cycles_dict[node.name]
-        assert np.isclose(exp_cycles, cycles_rtlsim, atol=11)
-        assert exp_cycles != 0
- """
+
+    num_transpose = len(new_model.get_nodes_by_op_type("Transpose"))
+    num_flatten = len(new_model.get_nodes_by_op_type("Flatten"))
+    num_reshape = len(new_model.get_nodes_by_op_type("Reshape"))
+
+    # check if transpose->flatten was removed
+    assert num_transpose == 1 and num_flatten == 0 and num_reshape == 0
-- 
GitLab