From ea3522badc31faf32d6b556f91ac10ca2680bb11 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 28 Oct 2021 22:07:42 +0200 Subject: [PATCH] [InsertDWC] handle StreamingConcat correctly --- .../transformation/fpgadataflow/insert_dwc.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py index 58efe65eb..4a0d0a89c 100644 --- a/src/finn/transformation/fpgadataflow/insert_dwc.py +++ b/src/finn/transformation/fpgadataflow/insert_dwc.py @@ -1,4 +1,3 @@ -import warnings from onnx import TensorProto from onnx import helper as oh @@ -48,23 +47,23 @@ class InsertDWC(Transformation): consumers = model.find_consumers(output_name) if consumers is None: continue - if len(consumers) > 1: - warnings.warn( - n.name - + ": HLS node with fan-out higher than 1 cannot be stitched" - ) - + assert len(consumers) == 1, ( + n.name + + ": HLS node with fan-out higher than 1 cannot be stitched" + ) consumer = consumers[0] if _suitable_node(consumer) is True: n0 = getCustomOp(n) n1 = getCustomOp(consumer) n0_out_shape = n0.get_folded_output_shape() - - # If FC and external mem, it could be connected to input 1 + # in some special cases, we need to get folded shapes of + # non-default inputs for the consumer + # - if FC and external mem, it could be connected to input 1 + # - if concat, could be connected to any input if ( consumer.op_type == "StreamingFCLayer_Batch" and n1.get_nodeattr("mem_mode") == "external" - ): + ) or (consumer.op_type == "StreamingConcat"): # get input idx in_idx = None for idx, n_input in enumerate(consumer.input): @@ -73,6 +72,7 @@ class InsertDWC(Transformation): assert in_idx is not None, "Malformed model" n1_in_shape = n1.get_folded_input_shape(in_idx) else: + # use default folded input shape n1_in_shape = n1.get_folded_input_shape() if n0_out_shape[-1] != n1_in_shape[-1]: -- GitLab