diff --git a/src/finn/transformation/fpgadataflow/insert_fifo.py b/src/finn/transformation/fpgadataflow/insert_fifo.py index e77774df729a6e7725e7a4589d3453d8725105a1..b9222cf3eef1a8566e698b4f58a85d100f824243 100644 --- a/src/finn/transformation/fpgadataflow/insert_fifo.py +++ b/src/finn/transformation/fpgadataflow/insert_fifo.py @@ -101,8 +101,8 @@ class InsertFIFO(Transformation): for first_node in graph.node: node_ind += 1 if _suitable_node(first_node): - for n_output in first_node.output: - consumers = model.find_consumers(n_output) + for idx_out, output_name in enumerate(first_node.output): + consumers = model.find_consumers(output_name) if consumers == []: continue if len(consumers) > 1: @@ -120,12 +120,15 @@ class InsertFIFO(Transformation): # check if folded_shape of output of first node and # input of the second node is equal n1 = getCustomOp(consumer) + idx_inp = 0 for idx, inp in enumerate(consumer.input): - if inp == n_output: + if inp == output_name: if idx == 0: fld_shape_2 = n1.get_folded_input_shape() + idx_inp = 0 else: fld_shape_2 = n1.get_folded_input_shape(ind=idx) + idx_inp = idx assert _suitable_folded_shapes( fld_shape, fld_shape_2 ), """The @@ -135,8 +138,15 @@ class InsertFIFO(Transformation): # check if outFIFOdepth attribute of first node # and inFIFOdepth attribute of consumer node is equal - n0_depth = n0.get_nodeattr("outFIFODepth") - n1_depth = n1.get_nodeattr("inFIFODepth") + if idx_out == 0: + n0_depth = n0.get_nodeattr("outFIFODepth") + else: + n0_depth = n0.get_nodeattr("outFIFODepths")[idx_out] + if idx_inp == 0: + n1_depth = n1.get_nodeattr("inFIFODepth") + else: + n1_depth = n1.get_nodeattr("inFIFODepths")[idx_inp] + if n0_depth == n1_depth: fifo_depth = n0_depth elif n0_depth != n1_depth: @@ -160,7 +170,7 @@ class InsertFIFO(Transformation): ) fifo_node = oh.make_node( "StreamingFIFO", - [n_output], + [output_name], [fifo_output_tensor.name], domain="finn.custom_op.fpgadataflow", backend="fpgadataflow", @@ -174,11 +184,22 @@ class InsertFIFO(Transformation): graph.node.insert(node_ind + 1, fifo_node) # set fifo output tensor as new input tensor of second node for idx, inp in enumerate(consumer.input): - if inp == n_output: + if inp == output_name: consumer.input[idx] = fifo_output_tensor.name # ensure created FIFO depth is reflected on both sides - n0.set_nodeattr("outFIFODepth", fifo_depth) - n1.set_nodeattr("inFIFODepth", fifo_depth) + if idx_out == 0: + n0.set_nodeattr("outFIFODepth", fifo_depth) + else: + odepths = n0.get_nodeattr("outFIFODepths") + odepths[idx_out] = fifo_depth + n0.set_nodeattr("outFIFODepths", odepths) + if idx_inp == 0: + n1.set_nodeattr("inFIFODepth", fifo_depth) + else: + idepths = n1.get_nodeattr("inFIFODepths") + idepths[idx_inp] = fifo_depth + n1.set_nodeattr("inFIFODepths", idepths) + graph_modified = True if graph_modified is False: