From a3553b57e208f98a2e32ea7ff5110b0e4d717d81 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Tue, 27 Oct 2020 14:02:16 +0000
Subject: [PATCH] [fpgadataflow] Add suport for non linear graphs in InsertDWC
 transf (#227)

---
 .../transformation/fpgadataflow/insert_dwc.py | 94 +++++++++++--------
 1 file changed, 53 insertions(+), 41 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py
index 195a005ff..b4b577605 100644
--- a/src/finn/transformation/fpgadataflow/insert_dwc.py
+++ b/src/finn/transformation/fpgadataflow/insert_dwc.py
@@ -4,6 +4,7 @@ from onnx import helper as oh
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.util.fpgadataflow import is_fpgadataflow_node
+import warnings
 
 
 def _is_dwc_node(node):
@@ -40,48 +41,59 @@ class InsertDWC(Transformation):
         for n in graph.node:
             node_ind += 1
             if _suitable_node(n):
-                n_output = n.output[0]
-                consumer = model.find_consumer(n_output)
-                if _suitable_node(consumer) is True:
-                    n0 = getCustomOp(n)
-                    n1 = getCustomOp(consumer)
-                    n0_out_shape = n0.get_folded_output_shape()
-                    n1_in_shape = n1.get_folded_input_shape()
-                    if n0_out_shape[-1] != n1_in_shape[-1]:
-                        graph_modified = True
-                        # determine dwc inwidth
-                        dwc_in_width = n0.get_outstream_width()
-                        # determine dwc outwidth
-                        dwc_out_width = n1.get_instream_width()
-
-                        # determine shape for dwc
-                        dwc_shape = n0.get_normal_output_shape()
-
-                        # determine dtype for dwc
-                        dtype = n0.get_output_datatype()
-
-                        dwc_output_tensor = oh.make_tensor_value_info(
-                            model.make_new_valueinfo_name(),
-                            TensorProto.FLOAT,
-                            dwc_shape,
+                for n_output in n.output:
+                    consumers = model.find_consumers(n_output)
+                    if consumers is None:
+                        continue
+                    if len(consumers) > 1:
+                        warnings.warn(
+                            n.name
+                            + ": HLS node with fan-out higher than 1 cannot be stitched"
                         )
-                        graph.value_info.append(dwc_output_tensor)
-
-                        dwc_node = oh.make_node(
-                            "StreamingDataWidthConverter_Batch",
-                            [n_output],
-                            [dwc_output_tensor.name],
-                            domain="finn",
-                            backend="fpgadataflow",
-                            shape=dwc_shape,
-                            inWidth=dwc_in_width,
-                            outWidth=dwc_out_width,
-                            dataType=str(dtype.name),
-                        )
-                        # insert dwc
-                        graph.node.insert(node_ind + 1, dwc_node)
 
-                        # set dwc output tensor as new input tensor of second node
-                        consumer.input[0] = dwc_output_tensor.name
+                    consumer = consumers[0]
+                    if _suitable_node(consumer) is True:
+                        n0 = getCustomOp(n)
+                        n1 = getCustomOp(consumer)
+                        n0_out_shape = n0.get_folded_output_shape()
+                        n1_in_shape = n1.get_folded_input_shape()
+                        if n0_out_shape[-1] != n1_in_shape[-1]:
+                            graph_modified = True
+                            # determine dwc inwidth
+                            dwc_in_width = n0.get_outstream_width()
+                            # determine dwc outwidth
+                            dwc_out_width = n1.get_instream_width()
+
+                            # determine shape for dwc
+                            dwc_shape = n0.get_normal_output_shape()
+
+                            # determine dtype for dwc
+                            dtype = n0.get_output_datatype()
+
+                            dwc_output_tensor = oh.make_tensor_value_info(
+                                model.make_new_valueinfo_name(),
+                                TensorProto.FLOAT,
+                                dwc_shape,
+                            )
+                            graph.value_info.append(dwc_output_tensor)
+
+                            dwc_node = oh.make_node(
+                                "StreamingDataWidthConverter_Batch",
+                                [n_output],
+                                [dwc_output_tensor.name],
+                                domain="finn",
+                                backend="fpgadataflow",
+                                shape=dwc_shape,
+                                inWidth=dwc_in_width,
+                                outWidth=dwc_out_width,
+                                dataType=str(dtype.name),
+                            )
+                            # insert dwc
+                            graph.node.insert(node_ind + 1, dwc_node)
+
+                            # set dwc output tensor as new input tensor of second node
+                            for idx, inp in enumerate(consumer.input):
+                                if inp == n_output:
+                                    consumer.input[idx] = dwc_output_tensor.name
 
         return (model, graph_modified)
-- 
GitLab