From 55a8828ff89ff167dd9397c8e02aa7a63efac49c Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Tue, 7 Jul 2020 22:38:59 +0100
Subject: [PATCH] [FPGADataflow] Add InferDuplicateStreamsLayer

---
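A minimal usage sketch, assuming a loaded FINN ModelWrapper instance `model`;
the `to_hls` import alias is illustrative only:

    import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

    # rewires any integer tensor with exactly two consumers through a
    # DuplicateStreams_Batch node, then restores topological order and
    # re-infers shapes and datatypes
    model = model.transform(to_hls.InferDuplicateStreamsLayer())
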
 .../fpgadataflow/convert_to_hls_layers.py     | 74 +++++++++++++++++++
 1 file changed, 74 insertions(+)

diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 844c21726..6fe6e97df 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -35,6 +35,7 @@ from finn.transformation import Transformation
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.general import SortGraph
 import finn.core.data_layout as DataLayout
 from finn.util.onnx import nchw_to_nhwc
 import warnings
@@ -780,6 +781,79 @@ class InferAddStreamsLayer(Transformation):
         return (model, graph_modified)
 
 
+class InferDuplicateStreamsLayer(Transformation):
+    """Insert a DuplicateStreams HLS layer for any tensor with fanout == 2 """
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for node in graph.node:
+            node_ind += 1
+            successors = model.find_consumers(node.output[0])
+            if successors is not None and len(successors) == 2:
+                output_tensor = node.output[0]
+
+                dt = model.get_tensor_datatype(output_tensor)
+
+                # skip conversion for non-integer (e.g. float) tensors
+                if not dt.is_integer():
+                    continue
+
+                # create clone tensors
+                out_shape = model.get_tensor_shape(output_tensor)
+                out_tensor_clones = []
+                for i in range(2):
+                    clone = helper.make_tensor_value_info(
+                        model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape
+                    )
+                    model.graph.value_info.append(clone)
+                    out_tensor_clones += [clone.name]
+
+                num_ch = int(out_shape[-1])
+                vecs = out_shape[:-1]
+
+                # create node with no parallelization first
+                pe = 1
+                assert (
+                    num_ch % pe == 0
+                ), "Requirement channels divisable by PE is violated."
+
+                dup_node = helper.make_node(
+                    "DuplicateStreams_Batch",
+                    [output_tensor],
+                    out_tensor_clones,
+                    domain="finn",
+                    backend="fpgadataflow",
+                    NumChannels=num_ch,
+                    PE=pe,
+                    inputDataType=dt.name,
+                    numInputVectors=vecs,
+                )
+
+                graph.node.insert(node_ind, dup_node)
+
+                # connect successors to out tensor clone
+                clone_idx = 0
+                for successor in successors:
+                    for i, succ_input in enumerate(successor.input):
+                        if succ_input == output_tensor:
+                            successor.input[i] = out_tensor_clones[clone_idx]
+                            clone_idx += 1
+                            # if a node consumes the same output tensor on
+                            # several inputs, find_consumers returns it once
+                            # per input, so breaking here rewires all of them
+                            break
+
+                graph_modified = True
+
+        if graph_modified:
+            model = model.transform(SortGraph())
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class InferChannelwiseLinearLayer(Transformation):
     """Convert any channel-wise Add/Mul into a HLS layer."""
 
-- 
GitLab