From 341a8a3ef25383396edba6d28aaa7ee3e0c51761 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@amd.com>
Date: Thu, 21 Jul 2022 20:11:30 +0200
Subject: [PATCH] [ToHLS] draft an InferStreamingEltwiseAbsDiff conversion

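This transformation looks for a Sub node whose output feeds an Abs node,
i.e. an elementwise |a - b|, and replaces the pair with a single
StreamingEltwise HLS node carrying eltwiseOp="AbsDiff". Both inputs must
have matching shapes and integer datatypes; NCHW inputs and outputs are
converted to NHWC around the new node, and PE=1 (no parallelization) is
used initially.

Intended usage (illustrative sketch; the import path matches the module
touched by this patch, but the surrounding flow is an assumption):

    import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

    # model is a FINN/QONNX ModelWrapper containing the Sub -> Abs pattern
    model = model.transform(to_hls.InferStreamingEltwiseAbsDiff())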
---
 .../fpgadataflow/convert_to_hls_layers.py     | 92 +++++++++++++++++++
 1 file changed, 92 insertions(+)

diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 429bc34ff..e8f6372ab 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -1671,3 +1671,95 @@ class InferConcatLayer(Transformation):
             model = model.transform(InferShapes())
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class InferStreamingEltwiseAbsDiff(Transformation):
+    """Convert eltwise Sub -> Abs to StreamingEltwise layer
+    with AbsDiffEltwise op."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for node in graph.node:
+            node_ind += 1
+            if node.op_type == "Sub":
+                in0 = node.input[0]
+                in1 = node.input[1]
+                result = node.output[0]
+                in0_shape = model.get_tensor_shape(in0)
+                in1_shape = model.get_tensor_shape(in1)
+
+                # skip if different shapes on inputs
+                if in0_shape != in1_shape:
+                    continue
+
+                idt0 = model.get_tensor_datatype(in0)
+                idt1 = model.get_tensor_datatype(in1)
+
+                # skip conversion if any input is non-integer (e.g. float)
+                if not (idt0.is_integer() and idt1.is_integer()):
+                    continue
+
+                # look for a downstream Abs node
+                res_consumer = model.find_consumer(result)
+                if res_consumer is None:
+                    continue
+                if res_consumer.op_type != "Abs":
+                    continue
+
+                result = res_consumer.output[0]
+
+                # check layout and convert if necessary
+                in0_layout = model.get_tensor_layout(in0)
+                in1_layout = model.get_tensor_layout(in1)
+                result_layout = model.get_tensor_layout(result)
+
+                if in0_layout == DataLayout.NCHW:
+                    in0 = nchw_to_nhwc(in0, model, node_ind)
+                    node_ind += 1
+                    in0_shape = model.get_tensor_shape(in0)
+
+                if in1_layout == DataLayout.NCHW:
+                    in1 = nchw_to_nhwc(in1, model, node_ind)
+                    node_ind += 1
+                    in1_shape = model.get_tensor_shape(in1)
+
+                # keep track of where we need to insert the HLS Op
+                # it has to be ahead of the output transform
+                insert_point = node_ind
+
+                if result_layout == DataLayout.NCHW:
+                    result = nchw_to_nhwc(result, model, node_ind, reverse=True)
+                    node_ind += 1
+
+                # now safe to assume num_channels is size of last dimension
+                num_channels = int(in0_shape[-1])
+                # create node with no parallelization first
+                pe = 1
+
+                # create and insert the new StreamingEltwise node
+                new_node = helper.make_node(
+                    "StreamingEltwise",
+                    [in0, in1],
+                    [result],
+                    domain="finn.custom_op.fpgadataflow",
+                    backend="fpgadataflow",
+                    NumChannels=num_channels,
+                    PE=pe,
+                    inputDataType0=idt0.name,
+                    inputDataType1=idt1.name,
+                    eltwiseOp="AbsDiff",
+                    numInputVectors=in0_shape[:-1],
+                    name="StreamingEltwise_" + node.name,
+                )
+                graph.node.insert(insert_point, new_node)
+                # remove old nodes
+                graph.node.remove(node)
+                graph.node.remove(res_consumer)
+                graph_modified = True
+
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
-- 
GitLab
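
A minimal sketch of the pattern this transformation targets, for reference.
Module paths and the DataType choice are assumptions and may differ between
FINN versions (ModelWrapper and DataType lived under finn.core before the
qonnx split):

    from onnx import TensorProto, helper
    from qonnx.core.datatype import DataType
    from qonnx.core.modelwrapper import ModelWrapper

    import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls

    # NHWC tensors with matching shapes; channels are the last dimension
    shape = [1, 4, 4, 8]
    in0 = helper.make_tensor_value_info("in0", TensorProto.FLOAT, shape)
    in1 = helper.make_tensor_value_info("in1", TensorProto.FLOAT, shape)
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, shape)

    sub_node = helper.make_node("Sub", ["in0", "in1"], ["diff"], name="Sub_0")
    abs_node = helper.make_node("Abs", ["diff"], ["out"], name="Abs_0")
    graph = helper.make_graph([sub_node, abs_node], "absdiff", [in0, in1], [out])
    model = ModelWrapper(helper.make_model(graph))

    # both inputs need integer FINN datatypes, otherwise the Sub node is skipped
    model.set_tensor_datatype("in0", DataType["UINT8"])
    model.set_tensor_datatype("in1", DataType["UINT8"])

    model = model.transform(to_hls.InferStreamingEltwiseAbsDiff())
    # the Sub/Abs pair should now be replaced by one StreamingEltwise node
    # with eltwiseOp="AbsDiff", NumChannels=8 and PE=1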