diff --git a/src/finn/transformation/fpgadataflow/create_stitched_ip.py b/src/finn/transformation/fpgadataflow/create_stitched_ip.py index 327c7867fe30485f6df51d5e98dcbbaceea04cd8..04bf054ea3be3d0dc4bef0e0445f9ee99095848d 100644 --- a/src/finn/transformation/fpgadataflow/create_stitched_ip.py +++ b/src/finn/transformation/fpgadataflow/create_stitched_ip.py @@ -223,8 +223,8 @@ class CreateStitchedIP(Transformation): behavior. It is strongly recommended to insert FIFOs prior to calling CreateStitchedIP.""" ) - # ensure that all nodes are fpgadataflow, and that IPs are generated for node in model.graph.node: + # ensure that all nodes are fpgadataflow, and that IPs are generated assert is_fpgadataflow_node( node ), "All nodes must be FINN fpgadataflow nodes." @@ -236,9 +236,7 @@ class CreateStitchedIP(Transformation): self.connect_clk_rst(node) self.connect_axi(node) for i in range(len(node.input)): - if is_external_input(model, node, i): - self.connect_s_axis_external(node, idx=i) - else: + if not is_external_input(model, node, i): producer = model.find_producer(node.input[i]) if producer is None: continue @@ -254,8 +252,25 @@ class CreateStitchedIP(Transformation): "[get_bd_intf_pins %s/%s]" % (producer.name, src_intf_name, node.name, dst_intf_name) ) + + # process external inputs and outputs in top-level graph input order + for input in model.graph.input: + inp_name = input.name + inp_cons = model.find_consumers(inp_name) + assert inp_cons is not None, "No consumer for input " + inp_name + assert len(inp_cons) == 1, "Multiple consumers for input " + inp_name + node = inp_cons[0] + node_inst = getCustomOp(node) + for i in range(len(node.input)): + if node.input[i] == inp_name: + self.connect_s_axis_external(node, idx=i) + for output in model.graph.output: + out_name = output.name + node = model.find_producer(out_name) + assert node is not None, "No producer for output " + out_name + node_inst = getCustomOp(node) for i in range(len(node.output)): - if is_external_output(model, node, i): + if node.output[i] == out_name: self.connect_m_axis_external(node, idx=i) # create a temporary folder for the project