From 5abfe83f75acead8a25da63e4d96023bc8d78710 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 30 Aug 2022 21:11:44 +0200
Subject: [PATCH] [FIFO] characterize Add/DuplicateStreams, optional bypass
 fix, pad

---
 .../fpgadataflow/derive_characteristic.py     | 33 +++++++++++++------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/derive_characteristic.py b/src/finn/transformation/fpgadataflow/derive_characteristic.py
index 651462066..53540e22d 100644
--- a/src/finn/transformation/fpgadataflow/derive_characteristic.py
+++ b/src/finn/transformation/fpgadataflow/derive_characteristic.py
@@ -53,9 +53,10 @@ class DeriveCharacteristic(NodeLocalTransformation):
       NodeLocalTransformation for more details.
     """
 
-    def __init__(self, period, num_workers=None):
+    def __init__(self, period, num_workers=None, manual_bypass=False):
         super().__init__(num_workers=num_workers)
         self.period = period
+        self.manual_bypass = manual_bypass
 
     def applyNodeLocal(self, node):
         op_type = node.op_type
@@ -77,8 +78,6 @@ class DeriveCharacteristic(NodeLocalTransformation):
                     return (node, False)
                 # restricted to single input and output nodes for now
                 multistream_optypes = [
-                    "AddStreams_Batch",
-                    "DuplicateStreams_Batch",
                     "StreamingConcat",
                 ]
                 if node.op_type in multistream_optypes:
@@ -107,9 +106,13 @@ class DeriveCharacteristic(NodeLocalTransformation):
                     },
                     "outputs": {"out": []},
                 }
-
-                txns_in = {"in0": []}
-                txns_out = {"out": []}
+                # override for certain fork/join nodes
+                if node.op_type == "DuplicateStreams_Batch":
+                    del io_dict["outputs"]["out"]
+                    io_dict["outputs"]["out0"] = []
+                    io_dict["outputs"]["out1"] = []
+                elif node.op_type == "AddStreams_Batch":
+                    io_dict["inputs"]["in1"] = [0 for i in range(n_inps)]
 
                 try:
                     # fill out weight stream for decoupled-mode components
@@ -123,10 +126,13 @@ class DeriveCharacteristic(NodeLocalTransformation):
                         io_dict["inputs"]["weights"] = [
                             0 for i in range(num_w_reps * n_weight_inps)
                         ]
-                        txns_in["weights"] = []
                 except AttributeError:
                     pass
 
+                # extra dicts to keep track of cycle-by-cycle transaction behavior
+                txns_in = {key: [] for (key, value) in io_dict["inputs"].items()}
+                txns_out = {key: [] for (key, value) in io_dict["outputs"].items()}
+
                 def monitor_txns(sim_obj):
                     for inp in io_dict["inputs"]:
                         in_ready = _read_signal(sim, inp + sname + "TREADY") == 1
@@ -156,11 +162,15 @@ class DeriveCharacteristic(NodeLocalTransformation):
                 assert total_cycle_count <= self.period
                 # restrict to single input-output stream only for now
                 txns_in = txns_in["in0"]
-                txns_out = txns_out["out"]
+                txns_out = txns_out[
+                    "out" if node.op_type != "DuplicateStreams_Batch" else "out0"
+                ]
                 if len(txns_in) < self.period:
-                    txns_in += [0 for x in range(self.period - len(txns_in))]
+                    pad_in = self.period - len(txns_in)
+                    txns_in += [0 for x in range(pad_in)]
                 if len(txns_out) < self.period:
-                    txns_out += [0 for x in range(self.period - len(txns_out))]
+                    pad_out = self.period - len(txns_out)
+                    txns_out += [0 for x in range(pad_out)]
 
                 def accumulate_char_fxn(chrc):
                     p = len(chrc)
@@ -177,6 +187,7 @@ class DeriveCharacteristic(NodeLocalTransformation):
                 io_characteristic = txns_in + txns_out
                 inst.set_nodeattr("io_characteristic", io_characteristic)
                 inst.set_nodeattr("io_characteristic_period", self.period)
+                inst.set_nodeattr("io_characteristic_pads", [pad_in, pad_out])
             except KeyError:
                 # exception if op_type is not supported
                 raise Exception(
@@ -186,6 +197,8 @@ class DeriveCharacteristic(NodeLocalTransformation):
 
     def apply(self, model: ModelWrapper):
         (model, run_again) = super().apply(model)
+        if not self.manual_bypass:
+            return (model, run_again)
         # apply manual fix for DuplicateStreams and AddStreams for
         # simple residual reconvergent paths with bypass
         addstrm_nodes = model.get_nodes_by_op_type("AddStreams_Batch")
-- 
GitLab