From 9b776821e3a5281b7dec05857b4ac19037677d04 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 31 Aug 2022 14:43:57 +0200
Subject: [PATCH] [FIFO] adapt FIFO sizing to new attributes

---
 .../fpgadataflow/derive_characteristic.py     | 84 +++++++++++--------
 1 file changed, 51 insertions(+), 33 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/derive_characteristic.py b/src/finn/transformation/fpgadataflow/derive_characteristic.py
index b821059bc..660476070 100644
--- a/src/finn/transformation/fpgadataflow/derive_characteristic.py
+++ b/src/finn/transformation/fpgadataflow/derive_characteristic.py
@@ -71,7 +71,7 @@ class DeriveCharacteristic(NodeLocalTransformation):
                 assert inst.get_nodeattr("rtlsim_so") != "", (
                     "rtlsim not ready for " + node.name
                 )
-                if inst.get_nodeattr("io_characteristic_period") > 0:
+                if inst.get_nodeattr("io_chrc_period") > 0:
                     warnings.warn(
                         "Skipping node %s: already has FIFO characteristic" % node.name
                     )
@@ -133,18 +133,25 @@ class DeriveCharacteristic(NodeLocalTransformation):
                     pass
 
                 # extra dicts to keep track of cycle-by-cycle transaction behavior
-                txns_in = {key: [] for (key, value) in io_dict["inputs"].items()}
-                txns_out = {key: [] for (key, value) in io_dict["outputs"].items()}
+                # note that we restrict key names to filter out weight streams etc
+                txns_in = {
+                    key: [] for (key, value) in io_dict["inputs"].items() if "in" in key
+                }
+                txns_out = {
+                    key: []
+                    for (key, value) in io_dict["outputs"].items()
+                    if "out" in key
+                }
 
                 def monitor_txns(sim_obj):
-                    for inp in io_dict["inputs"]:
+                    for inp in txns_in:
                         in_ready = _read_signal(sim, inp + sname + "TREADY") == 1
                         in_valid = _read_signal(sim, inp + sname + "TVALID") == 1
                         if in_ready and in_valid:
                             txns_in[inp].append(1)
                         else:
                             txns_in[inp].append(0)
-                    for outp in io_dict["outputs"]:
+                    for outp in txns_out:
                         if (
                             _read_signal(sim, outp + sname + "TREADY") == 1
                             and _read_signal(sim, outp + sname + "TVALID") == 1
@@ -163,17 +170,7 @@ class DeriveCharacteristic(NodeLocalTransformation):
                     hook_preclk=monitor_txns,
                 )
                 assert total_cycle_count <= self.period
-                # restrict to single input-output stream only for now
-                txns_in = txns_in["in0"]
-                txns_out = txns_out[
-                    "out" if node.op_type != "DuplicateStreams_Batch" else "out0"
-                ]
-                if len(txns_in) < self.period:
-                    pad_in = self.period - len(txns_in)
-                    txns_in += [0 for x in range(pad_in)]
-                if len(txns_out) < self.period:
-                    pad_out = self.period - len(txns_out)
-                    txns_out += [0 for x in range(pad_out)]
+                inst.set_nodeattr("io_chrc_period", self.period)
 
                 def accumulate_char_fxn(chrc):
                     p = len(chrc)
@@ -183,14 +180,36 @@ class DeriveCharacteristic(NodeLocalTransformation):
                             ret.append(chrc[0])
                         else:
                             ret.append(ret[-1] + chrc[t % p])
-                    return ret
+                    return np.asarray(ret, dtype=np.int32)
+
+                all_txns_in = np.empty((len(txns_in.keys()), 2 * self.period))
+                all_txns_out = np.empty((len(txns_out.keys()), 2 * self.period))
+                all_pad_in = []
+                all_pad_out = []
+                for in_idx, in_strm_nm in enumerate(txns_in.keys()):
+                    txn_in = txns_in[in_strm_nm]
+                    # pad to a full period; 0 if already full (never stale)
+                    pad_in = max(0, self.period - len(txn_in))
+                    txn_in += [0] * pad_in
+                    txn_in = accumulate_char_fxn(txn_in)
+                    all_txns_in[in_idx, :] = txn_in
+                    all_pad_in.append(pad_in)
+
+                for out_idx, out_strm_nm in enumerate(txns_out.keys()):
+                    txn_out = txns_out[out_strm_nm]
+                    # pad to a full period; 0 if already full (never stale)
+                    pad_out = max(0, self.period - len(txn_out))
+                    txn_out += [0] * pad_out
+                    txn_out = accumulate_char_fxn(txn_out)
+                    all_txns_out[out_idx, :] = txn_out
+                    all_pad_out.append(pad_out)
+
+                # TODO specialize here for DuplicateStreams and AddStreams
+                inst.set_nodeattr("io_chrc_in", all_txns_in)
+                inst.set_nodeattr("io_chrc_out", all_txns_out)
+                inst.set_nodeattr("io_chrc_pads_in", all_pad_in)
+                inst.set_nodeattr("io_chrc_pads_out", all_pad_out)
 
-                txns_in = accumulate_char_fxn(txns_in)
-                txns_out = accumulate_char_fxn(txns_out)
-                io_characteristic = txns_in + txns_out
-                inst.set_nodeattr("io_characteristic", io_characteristic)
-                inst.set_nodeattr("io_characteristic_period", self.period)
-                inst.set_nodeattr("io_characteristic_pads", [pad_in, pad_out])
             except KeyError:
                 # exception if op_type is not supported
                 raise Exception(
@@ -230,7 +249,7 @@ class DeriveCharacteristic(NodeLocalTransformation):
             comp_branch_first = registry.getCustomOp(comp_branch_first)
             # for DuplicateStreams, use comp_branch_first's input characterization
             # for AddStreams, use comp_branch_last's output characterization
-            period = comp_branch_first.get_nodeattr("io_characteristic_period")
+            period = comp_branch_first.get_nodeattr("io_chrc_period")
             comp_branch_first_f = comp_branch_first.get_nodeattr("io_characteristic")[
                 : 2 * period
             ]
@@ -239,9 +258,9 @@ class DeriveCharacteristic(NodeLocalTransformation):
             ]
             ds_node_inst = registry.getCustomOp(ds_node)
             addstrm_node_inst = registry.getCustomOp(addstrm_node)
-            ds_node_inst.set_nodeattr("io_characteristic_period", period)
+            ds_node_inst.set_nodeattr("io_chrc_period", period)
             ds_node_inst.set_nodeattr("io_characteristic", comp_branch_first_f * 2)
-            addstrm_node_inst.set_nodeattr("io_characteristic_period", period)
+            addstrm_node_inst.set_nodeattr("io_chrc_period", period)
             addstrm_node_inst.set_nodeattr("io_characteristic", comp_branch_last_f * 2)
             warnings.warn(
                 f"Set {ds_node.name} chrc. from {comp_branch_first.onnx_node.name}"
@@ -272,15 +291,15 @@ class DeriveFIFOSizes(NodeLocalTransformation):
                 # lookup op_type in registry of CustomOps
                 prod = registry.getCustomOp(node)
                 assert op_type != "StreamingFIFO", "Found existing FIFOs"
-                period = prod.get_nodeattr("io_characteristic_period")
-                prod_chrc = prod.get_nodeattr("io_characteristic")
+                period = prod.get_nodeattr("io_chrc_period")
+                prod_chrc = prod.get_nodeattr("io_chrc_out")[0]
                 assert (
-                    len(prod_chrc) == 4 * period
+                    len(prod_chrc) == 2 * period
                 ), "Found unexpected characterization attribute"
                 if prod.get_nodeattr("outFIFODepth") > 2:
                     # FIFO depth already set, can skip this node
                     return (node, False)
-                prod_chrc = np.asarray(prod_chrc).reshape(2, -1)[1]
+
                 # find consumers
                 model = self.ref_input_model
                 out_fifo_depths = []
@@ -292,8 +311,7 @@ class DeriveFIFOSizes(NodeLocalTransformation):
                         out_fifo_depths.append(2)
                         continue
                     cons = registry.getCustomOp(cons_node)
-                    cons_chrc = cons.get_nodeattr("io_characteristic")
-                    cons_chrc = np.asarray(cons_chrc).reshape(2, -1)[0]
+                    cons_chrc = cons.get_nodeattr("io_chrc_in")[0]
                     # find minimum phase shift satisfying the constraint
                     pshift_min = period - 1
                     for pshift_cand in range(period):
@@ -304,7 +322,7 @@ class DeriveFIFOSizes(NodeLocalTransformation):
                             break
                     prod_chrc_part = prod_chrc[pshift_min : (pshift_min + period)]
                     cons_chrc_part = cons_chrc[:period]
-                    fifo_depth = (prod_chrc_part - cons_chrc_part).max()
+                    fifo_depth = int((prod_chrc_part - cons_chrc_part).max())
                     out_fifo_depths.append(fifo_depth)
                 # set output FIFO depth for this (producing) node
                 # InsertFIFO looks at the max of (outFIFODepth, inFIFODepth)
-- 
GitLab