From ef902db989605db79d1d7ee4a1408a0754e581db Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 21 Apr 2022 09:43:47 +0100 Subject: [PATCH] [Refactor] account for changes in find_consumers behavior --- src/finn/transformation/fpgadataflow/create_stitched_ip.py | 6 ++++-- src/finn/transformation/fpgadataflow/insert_dwc.py | 2 +- src/finn/transformation/fpgadataflow/insert_fifo.py | 2 +- src/finn/transformation/fpgadataflow/insert_tlastmarker.py | 2 +- src/finn/transformation/fpgadataflow/make_zynq_proj.py | 6 ++++-- src/finn/transformation/fpgadataflow/set_fifo_depths.py | 2 +- src/finn/transformation/fpgadataflow/vitis_build.py | 4 +++- src/finn/transformation/qonnx/fold_quant_weights.py | 5 ++++- src/finn/transformation/streamline/absorb.py | 7 +++---- src/finn/transformation/streamline/reorder.py | 1 + 10 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/finn/transformation/fpgadataflow/create_stitched_ip.py b/src/finn/transformation/fpgadataflow/create_stitched_ip.py index c88c65fc8..933cb2757 100644 --- a/src/finn/transformation/fpgadataflow/create_stitched_ip.py +++ b/src/finn/transformation/fpgadataflow/create_stitched_ip.py @@ -64,7 +64,9 @@ def is_external_output(model, node, i): # indicate whether output i of node should be made external # True only if output is unconnected consumers = model.find_consumers(node.output[i]) - if consumers is None: + if consumers == []: + # TODO should ideally check if tensor is in top-level + # outputs return True return False @@ -270,7 +272,7 @@ class CreateStitchedIP(Transformation): for input in model.graph.input: inp_name = input.name inp_cons = model.find_consumers(inp_name) - assert inp_cons is not None, "No consumer for input " + inp_name + assert inp_cons != [], "No consumer for input " + inp_name assert len(inp_cons) == 1, "Multiple consumers for input " + inp_name node = inp_cons[0] node_inst = getCustomOp(node) diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py index 4a0d0a89c..afc889f5b 100644 --- a/src/finn/transformation/fpgadataflow/insert_dwc.py +++ b/src/finn/transformation/fpgadataflow/insert_dwc.py @@ -45,7 +45,7 @@ class InsertDWC(Transformation): if _suitable_node(n): for output_name in n.output: consumers = model.find_consumers(output_name) - if consumers is None: + if consumers == []: continue assert len(consumers) == 1, ( n.name diff --git a/src/finn/transformation/fpgadataflow/insert_fifo.py b/src/finn/transformation/fpgadataflow/insert_fifo.py index b5ae2da47..266138490 100644 --- a/src/finn/transformation/fpgadataflow/insert_fifo.py +++ b/src/finn/transformation/fpgadataflow/insert_fifo.py @@ -62,7 +62,7 @@ class InsertFIFO(Transformation): if _suitable_node(first_node): for n_output in first_node.output: consumers = model.find_consumers(n_output) - if consumers is None: + if consumers == []: continue if len(consumers) > 1: warnings.warn( diff --git a/src/finn/transformation/fpgadataflow/insert_tlastmarker.py b/src/finn/transformation/fpgadataflow/insert_tlastmarker.py index 34cb61346..a33cee464 100644 --- a/src/finn/transformation/fpgadataflow/insert_tlastmarker.py +++ b/src/finn/transformation/fpgadataflow/insert_tlastmarker.py @@ -97,7 +97,7 @@ class InsertTLastMarker(Transformation): first_node = model.find_consumers(graph_in_name) # skip if no consumers (this may be the case for unused initializers) # TODO: fix this with a cleanup transform - if first_node is None: + if first_node == []: continue assert len(first_node) == 1, "Input fans out to multiple nodes" first_node = first_node[0] diff --git a/src/finn/transformation/fpgadataflow/make_zynq_proj.py b/src/finn/transformation/fpgadataflow/make_zynq_proj.py index 80ce8f016..da620766a 100644 --- a/src/finn/transformation/fpgadataflow/make_zynq_proj.py +++ b/src/finn/transformation/fpgadataflow/make_zynq_proj.py @@ -152,11 +152,13 @@ class MakeZYNQProject(Transformation): # define kernel instances # name kernels connected to graph inputs as idmaxx # name kernels connected to graph outputs as odmaxx - if producer is None or consumer is None: + if (producer is None) or (consumer == []): + # TODO not a good way of checking for external inp&out + # should look at the list of top-level in/out instead if producer is None: instance_names[node.name] = "idma" + str(idma_idx) idma_idx += 1 - elif consumer is None: + elif consumer == []: instance_names[node.name] = "odma" + str(odma_idx) odma_idx += 1 config.append( diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py index ce7cf7bc5..28f74b529 100644 --- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py +++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py @@ -99,7 +99,7 @@ class RemoveShallowFIFOs(Transformation): # bypass shallow fifos shallow_fifos.append(node) consumers = model.find_consumers(node.output[0]) - if consumers is None: + if consumers == []: producer = model.find_producer(node.input[0]) for idx, inp in enumerate(producer.output): if inp == node.input[0]: diff --git a/src/finn/transformation/fpgadataflow/vitis_build.py b/src/finn/transformation/fpgadataflow/vitis_build.py index 365632cd5..4dce3ab16 100644 --- a/src/finn/transformation/fpgadataflow/vitis_build.py +++ b/src/finn/transformation/fpgadataflow/vitis_build.py @@ -213,11 +213,13 @@ class VitisLink(Transformation): # define kernel instances # name kernels connected to graph inputs as idmaxx # name kernels connected to graph inputs as odmaxx + # TODO not a good way of checking for external in/out + # check top-level in/out list instead if producer is None: instance_names[node.name] = "idma" + str(idma_idx) config.append("nk=%s:1:%s" % (node.name, instance_names[node.name])) idma_idx += 1 - elif consumer is None: + elif consumer == []: instance_names[node.name] = "odma" + str(odma_idx) config.append("nk=%s:1:%s" % (node.name, instance_names[node.name])) odma_idx += 1 diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py index c81085e7a..e8a0f418a 100644 --- a/src/finn/transformation/qonnx/fold_quant_weights.py +++ b/src/finn/transformation/qonnx/fold_quant_weights.py @@ -146,11 +146,14 @@ class FoldQuantWeights(Transformation): model.set_initializer(mul_tensor.name, scale) successor = model.find_consumers(node_out) - if successor is None: + if successor == []: raise RuntimeError( "Can only constant fold scaled Quant weights " "if a successor exists." ) + assert ( + len(successor) == 1 + ), "Only implemented for a single consumer" successor = successor[0] succ_output_name = successor.output[0] diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 97ae3b51a..32e539d87 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -627,10 +627,9 @@ class AbsorbTransposeIntoResize(Transformation): graph.node.insert(node_ind + 1, new_transpose) # rewire nodes final_t_cands = model.find_consumers(mt_cand.output[0]) - if final_t_cands is not None: - # rewire next nodes' inputs - for final_t_cand in final_t_cands: - final_t_cand.input[0] = trans_output + # rewire next nodes' inputs + for final_t_cand in final_t_cands: + final_t_cand.input[0] = trans_output mt_cand.output[0] = trans_input graph_modified = True if graph_modified: diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 9751f41d7..e922dffe3 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -755,6 +755,7 @@ class MoveOpPastFork(Transformation): # Check case when branches are empty and go # to the same node consumers = model.find_consumers(n.output[0]) + assert len(consumers) > 1, "Must have >1 consumer" unique_consumer = True for consum_node in consumers[1:]: if consumers[0] != consum_node: -- GitLab