From ef902db989605db79d1d7ee4a1408a0754e581db Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 21 Apr 2022 09:43:47 +0100
Subject: [PATCH] [Refactor] account for changes in find_consumers behavior

---
 src/finn/transformation/fpgadataflow/create_stitched_ip.py | 6 ++++--
 src/finn/transformation/fpgadataflow/insert_dwc.py         | 2 +-
 src/finn/transformation/fpgadataflow/insert_fifo.py        | 2 +-
 src/finn/transformation/fpgadataflow/insert_tlastmarker.py | 2 +-
 src/finn/transformation/fpgadataflow/make_zynq_proj.py     | 6 ++++--
 src/finn/transformation/fpgadataflow/set_fifo_depths.py    | 2 +-
 src/finn/transformation/fpgadataflow/vitis_build.py        | 4 +++-
 src/finn/transformation/qonnx/fold_quant_weights.py        | 5 ++++-
 src/finn/transformation/streamline/absorb.py               | 7 +++----
 src/finn/transformation/streamline/reorder.py              | 1 +
 10 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/create_stitched_ip.py b/src/finn/transformation/fpgadataflow/create_stitched_ip.py
index c88c65fc8..933cb2757 100644
--- a/src/finn/transformation/fpgadataflow/create_stitched_ip.py
+++ b/src/finn/transformation/fpgadataflow/create_stitched_ip.py
@@ -64,7 +64,9 @@ def is_external_output(model, node, i):
     # indicate whether output i of node should be made external
     # True only if output is unconnected
     consumers = model.find_consumers(node.output[i])
-    if consumers is None:
+    if consumers == []:
+        # TODO should ideally check if tensor is in top-level
+        # outputs
         return True
     return False
 
@@ -270,7 +272,7 @@ class CreateStitchedIP(Transformation):
         for input in model.graph.input:
             inp_name = input.name
             inp_cons = model.find_consumers(inp_name)
-            assert inp_cons is not None, "No consumer for input " + inp_name
+            assert inp_cons != [], "No consumer for input " + inp_name
             assert len(inp_cons) == 1, "Multiple consumers for input " + inp_name
             node = inp_cons[0]
             node_inst = getCustomOp(node)
diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py
index 4a0d0a89c..afc889f5b 100644
--- a/src/finn/transformation/fpgadataflow/insert_dwc.py
+++ b/src/finn/transformation/fpgadataflow/insert_dwc.py
@@ -45,7 +45,7 @@ class InsertDWC(Transformation):
             if _suitable_node(n):
                 for output_name in n.output:
                     consumers = model.find_consumers(output_name)
-                    if consumers is None:
+                    if consumers == []:
                         continue
                     assert len(consumers) == 1, (
                         n.name
diff --git a/src/finn/transformation/fpgadataflow/insert_fifo.py b/src/finn/transformation/fpgadataflow/insert_fifo.py
index b5ae2da47..266138490 100644
--- a/src/finn/transformation/fpgadataflow/insert_fifo.py
+++ b/src/finn/transformation/fpgadataflow/insert_fifo.py
@@ -62,7 +62,7 @@ class InsertFIFO(Transformation):
             if _suitable_node(first_node):
                 for n_output in first_node.output:
                     consumers = model.find_consumers(n_output)
-                    if consumers is None:
+                    if consumers == []:
                         continue
                     if len(consumers) > 1:
                         warnings.warn(
diff --git a/src/finn/transformation/fpgadataflow/insert_tlastmarker.py b/src/finn/transformation/fpgadataflow/insert_tlastmarker.py
index 34cb61346..a33cee464 100644
--- a/src/finn/transformation/fpgadataflow/insert_tlastmarker.py
+++ b/src/finn/transformation/fpgadataflow/insert_tlastmarker.py
@@ -97,7 +97,7 @@ class InsertTLastMarker(Transformation):
                 first_node = model.find_consumers(graph_in_name)
                 # skip if no consumers (this may be the case for unused initializers)
                 # TODO: fix this with a cleanup transform
-                if first_node is None:
+                if first_node == []:
                     continue
                 assert len(first_node) == 1, "Input fans out to multiple nodes"
                 first_node = first_node[0]
diff --git a/src/finn/transformation/fpgadataflow/make_zynq_proj.py b/src/finn/transformation/fpgadataflow/make_zynq_proj.py
index 80ce8f016..da620766a 100644
--- a/src/finn/transformation/fpgadataflow/make_zynq_proj.py
+++ b/src/finn/transformation/fpgadataflow/make_zynq_proj.py
@@ -152,11 +152,13 @@ class MakeZYNQProject(Transformation):
             # define kernel instances
             # name kernels connected to graph inputs as idmaxx
             # name kernels connected to graph outputs as odmaxx
-            if producer is None or consumer is None:
+            if (producer is None) or (consumer == []):
+                # TODO not a good way of checking for external inp&out
+                # should look at the list of top-level in/out instead
                 if producer is None:
                     instance_names[node.name] = "idma" + str(idma_idx)
                     idma_idx += 1
-                elif consumer is None:
+                elif consumer == []:
                     instance_names[node.name] = "odma" + str(odma_idx)
                     odma_idx += 1
                 config.append(
diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
index ce7cf7bc5..28f74b529 100644
--- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py
+++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
@@ -99,7 +99,7 @@ class RemoveShallowFIFOs(Transformation):
                 # bypass shallow fifos
                 shallow_fifos.append(node)
                 consumers = model.find_consumers(node.output[0])
-                if consumers is None:
+                if consumers == []:
                     producer = model.find_producer(node.input[0])
                     for idx, inp in enumerate(producer.output):
                         if inp == node.input[0]:
diff --git a/src/finn/transformation/fpgadataflow/vitis_build.py b/src/finn/transformation/fpgadataflow/vitis_build.py
index 365632cd5..4dce3ab16 100644
--- a/src/finn/transformation/fpgadataflow/vitis_build.py
+++ b/src/finn/transformation/fpgadataflow/vitis_build.py
@@ -213,11 +213,13 @@ class VitisLink(Transformation):
             # define kernel instances
             # name kernels connected to graph inputs as idmaxx
             # name kernels connected to graph inputs as odmaxx
+            # TODO not a good way of checking for external in/out
+            # check top-level in/out list instead
             if producer is None:
                 instance_names[node.name] = "idma" + str(idma_idx)
                 config.append("nk=%s:1:%s" % (node.name, instance_names[node.name]))
                 idma_idx += 1
-            elif consumer is None:
+            elif consumer == []:
                 instance_names[node.name] = "odma" + str(odma_idx)
                 config.append("nk=%s:1:%s" % (node.name, instance_names[node.name]))
                 odma_idx += 1
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index c81085e7a..e8a0f418a 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -146,11 +146,14 @@ class FoldQuantWeights(Transformation):
                         model.set_initializer(mul_tensor.name, scale)
 
                         successor = model.find_consumers(node_out)
-                        if successor is None:
+                        if successor == []:
                             raise RuntimeError(
                                 "Can only constant fold scaled Quant weights "
                                 "if a successor exists."
                             )
+                        assert (
+                            len(successor) == 1
+                        ), "Only implemented for a single consumer"
                         successor = successor[0]
                         succ_output_name = successor.output[0]
 
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 97ae3b51a..32e539d87 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -627,10 +627,9 @@ class AbsorbTransposeIntoResize(Transformation):
                         graph.node.insert(node_ind + 1, new_transpose)
                         # rewire nodes
                         final_t_cands = model.find_consumers(mt_cand.output[0])
-                        if final_t_cands is not None:
-                            # rewire next nodes' inputs
-                            for final_t_cand in final_t_cands:
-                                final_t_cand.input[0] = trans_output
+                        # rewire next nodes' inputs
+                        for final_t_cand in final_t_cands:
+                            final_t_cand.input[0] = trans_output
                         mt_cand.output[0] = trans_input
                         graph_modified = True
         if graph_modified:
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 9751f41d7..e922dffe3 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -755,6 +755,7 @@ class MoveOpPastFork(Transformation):
                 # Check case when branches are empty and go
                 # to the same node
                 consumers = model.find_consumers(n.output[0])
+                assert len(consumers) > 1, "Must have >1 consumer"
                 unique_consumer = True
                 for consum_node in consumers[1:]:
                     if consumers[0] != consum_node:
-- 
GitLab