From efb7057dff172ef9f7692a4e7428fc3a7f66ae64 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 8 Oct 2020 16:07:00 +0200
Subject: [PATCH] [Refactor] move RemoveShallowFIFOs to own transformation

---
 .../fpgadataflow/set_fifo_depths.py           | 69 +++++++++++--------
 1 file changed, 42 insertions(+), 27 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
index bc795aa92..713148d7f 100644
--- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py
+++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
@@ -77,6 +77,44 @@ def optimize_depth(depth):
     return int(2 ** math.ceil(math.log2(depth)))
 
 
+class RemoveShallowFIFOs(Transformation):
+    """Remove small FIFOs as the streaming components have depth-2 FIFOs on the
+    input/outputs by default."""
+
+    # TODO add unit test
+
+    def __init__(self, shallow_threshold=2):
+        self.shallow_threshold = shallow_threshold
+
+    def apply(self, model):
+        shallow_fifos = []
+        for node in model.graph.node:
+            if (
+                node.op_type == "StreamingFIFO"
+                and getCustomOp(node).get_nodeattr("depth") <= self.shallow_threshold
+            ):
+                # bypass shallow fifos
+                shallow_fifos.append(node)
+                consumers = model.find_consumers(node.output[0])
+                if consumers is None:
+                    producer = model.find_producer(node.input[0])
+                    for idx, inp in enumerate(producer.output):
+                        if inp == node.input[0]:
+                            producer.output[idx] = node.output[0]
+                else:
+                    assert len(consumers) == 1, "Fanout detected from FIFO output"
+                    consumer = consumers[0]
+                    # set fifo input tensor as new input tensor of second node
+                    for idx, inp in enumerate(consumer.input):
+                        if inp == node.output[0]:
+                            consumer.input[idx] = node.input[0]
+        # now filter out
+        for node_to_remove in shallow_fifos:
+            model.graph.node.remove(node_to_remove)
+
+        return (model, False)
+
+
 class CapConvolutionFIFODepths(Transformation):
     """Make the size of FIFOs for convolution layers smaller where possible.
     Will be automatically called from InsertAndSetFIFODepths if the appropriate
@@ -101,6 +139,8 @@ class CapConvolutionFIFODepths(Transformation):
     than 1 row.
     """
 
+    # TODO add unit test
+
     def __init__(self, max_qsrl_depth=256):
         super().__init__()
         self.max_qsrl_depth = max_qsrl_depth
@@ -346,32 +386,7 @@ class InsertAndSetFIFODepths(Transformation):
             model = model.transform(
                 CapConvolutionFIFODepths(max_qsrl_depth=self.max_qsrl_depth)
             )
-
-        # Remove FIFOs which have depth <= 2
-        # TODO move this to own transformation
-        shallow_fifos = []
-        # First, bypass them
-        for node in model.graph.node:
-            if (
-                node.op_type == "StreamingFIFO"
-                and getCustomOp(node).get_nodeattr("depth") <= 2
-            ):
-                shallow_fifos.append(node)
-                consumers = model.find_consumers(node.output[0])
-                if consumers is None:
-                    producer = model.find_producer(node.input[0])
-                    for idx, inp in enumerate(producer.output):
-                        if inp == node.input[0]:
-                            producer.output[idx] = node.output[0]
-                else:
-                    assert len(consumers) == 1, "Fanout detected from FIFO output"
-                    consumer = consumers[0]
-                    # set fifo input tensor as new input tensor of second node
-                    for idx, inp in enumerate(consumer.input):
-                        if inp == node.output[0]:
-                            consumer.input[idx] = node.input[0]
-        # now filter out
-        for node_to_remove in shallow_fifos:
-            model.graph.node.remove(node_to_remove)
+        # remove shallow FIFOs
+        model = model.transform(RemoveShallowFIFOs())
 
         return (model, False)
-- 
GitLab