diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py index e26e92391edd8ac420e89c72fb34c5554c601967..aef793477168028d398c697f958c4cc729ba4ec0 100644 --- a/src/finn/transformation/fpgadataflow/insert_dwc.py +++ b/src/finn/transformation/fpgadataflow/insert_dwc.py @@ -56,7 +56,20 @@ class InsertDWC(Transformation): n0 = getCustomOp(n) n1 = getCustomOp(consumer) n0_out_shape = n0.get_folded_output_shape() - n1_in_shape = n1.get_folded_input_shape() + + # If FC and external mem, it could be connected to input 1 + if (consumer.op_type == "StreamingFCLayer_Batch" and + n1.get_nodeattr("mem_mode") == "external"): + # get input idx + in_idx = None + for idx, n_input in enumerate(consumer.input): + if n_output == n_input: + in_idx = idx + assert in_idx is not None,"Malformed model" + n1_in_shape = n1.get_folded_input_shape(in_idx) + else: + n1_in_shape = n1.get_folded_input_shape() + if n0_out_shape[-1] != n1_in_shape[-1]: graph_modified = True # determine dwc inwidth