From ea3522badc31faf32d6b556f91ac10ca2680bb11 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 28 Oct 2021 22:07:42 +0200
Subject: [PATCH] [InsertDWC] handle StreamingConcat correctly

---
 .../transformation/fpgadataflow/insert_dwc.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py
index 58efe65eb..4a0d0a89c 100644
--- a/src/finn/transformation/fpgadataflow/insert_dwc.py
+++ b/src/finn/transformation/fpgadataflow/insert_dwc.py
@@ -1,4 +1,3 @@
-import warnings
 from onnx import TensorProto
 from onnx import helper as oh
 
@@ -48,23 +47,23 @@ class InsertDWC(Transformation):
                     consumers = model.find_consumers(output_name)
                     if consumers is None:
                         continue
-                    if len(consumers) > 1:
-                        warnings.warn(
-                            n.name
-                            + ": HLS node with fan-out higher than 1 cannot be stitched"
-                        )
-
+                    assert len(consumers) == 1, (
+                        n.name
+                        + ": HLS node with fan-out higher than 1 cannot be stitched"
+                    )
                     consumer = consumers[0]
                     if _suitable_node(consumer) is True:
                         n0 = getCustomOp(n)
                         n1 = getCustomOp(consumer)
                         n0_out_shape = n0.get_folded_output_shape()
-
-                        # If FC and external mem, it could be connected to input 1
+                        # In some special cases the producer may feed a non-default
+                        # input of the consumer, so look up that input's folded shape:
+                        # - StreamingFCLayer_Batch, external mem_mode: may be input 1
+                        # - StreamingConcat: may be any input
                         if (
                             consumer.op_type == "StreamingFCLayer_Batch"
                             and n1.get_nodeattr("mem_mode") == "external"
-                        ):
+                        ) or (consumer.op_type == "StreamingConcat"):
                             # get input idx
                             in_idx = None
                             for idx, n_input in enumerate(consumer.input):
@@ -73,6 +72,7 @@ class InsertDWC(Transformation):
                             assert in_idx is not None, "Malformed model"
                             n1_in_shape = n1.get_folded_input_shape(in_idx)
                         else:
+                            # use default folded input shape
                             n1_in_shape = n1.get_folded_input_shape()
 
                         if n0_out_shape[-1] != n1_in_shape[-1]:
-- 
GitLab