From 5ce16895771b383e269b74c73b070424b31370fd Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 30 Oct 2019 14:59:44 +0000
Subject: [PATCH] [Integration] Finish interface of execute_custom_node.py to
 onnx_exec.py and add first draft of unit test

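execute_custom_node.py now executes a single custom node with inputs and
outputs provided via an execution context, so onnx_exec.py can hand custom
nodes over to it. A sketch of the intended call site follows; the function
name and dispatch condition are assumptions, and the 'finn' domain is taken
from the unit test below:

    import finn.core.execute_custom_node as ex_cu_node

    def execute_node(node, context, graph):
        if node.domain == "finn":
            # custom nodes are handled by the new executor
            ex_cu_node.execute_custom_node(node, context, graph)
        else:
            # standard ONNX nodes keep using the existing execution path
            ...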
---
 src/finn/core/execute_custom_node.py |  42 +++++++-
 tests/test_custom_onnx_exec.py       | 149 +++++++++++++++++++++++++++
 2 files changed, 187 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_custom_onnx_exec.py

diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py
index 064ed5f22..678794a53 100644
--- a/src/finn/core/execute_custom_node.py
+++ b/src/finn/core/execute_custom_node.py
@@ -1,8 +1,42 @@
 #import onnx.helper as helper
 
-#import finn.core.MultiThreshold
+import finn.core.MultiThreshold as multiThresh
 
-def execute_custom_node(node, context, graph)
+def execute_custom_node(node, context, graph):
     """Call custom implementation to execute a single custom node. Input/output provided via context."""
-    node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
-    print(node_inputs)
+
+    if node.op_type == 'MultiThreshold':
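+        # look up the value_info entries for this node's inputs in the graph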
+        node_inputs = list(filter(lambda x: x.name in node.input, graph.input))
+
+        # count tensor dimensions to tell the data input apart from the thresholds
+        shape_dict = {}
+        for inputs in node_inputs:
+            shape_dict[inputs.name] = 0
+            for dim_value in inputs.type.tensor_type.shape.dim:
+                shape_dict[inputs.name] += 1
+
+        # the 4-dimensional tensor is the data input, the other holds the thresholds
+        for inputs in node_inputs:
+            if shape_dict[inputs.name] == 4:
+                v = context[inputs.name]
+            else:
+                thresholds = context[inputs.name]
+
+        # execute() is expected to return one output array per entry in node.output
+        output_list = multiThresh.execute(v, thresholds)
+
+        # store the outputs in the context, checking shapes first
+        for output_ind in range(len(node.output)):
+            outp = node.output[output_ind]
+            if output_list[output_ind].shape != context[outp].shape:
+                raise Exception(
+                    "Output shapes disagree after node execution: found %s vs expected %s"
+                    % (str(output_list[output_ind].shape), str(context[outp].shape))
+                )
+            context[outp] = output_list[output_ind]
+
+    else:
+        raise Exception(
+            "The custom op_type %s is currently not supported." % node.op_type
+        )
diff --git a/tests/test_custom_onnx_exec.py b/tests/test_custom_onnx_exec.py
new file mode 100644
index 000000000..828e10d1d
--- /dev/null
+++ b/tests/test_custom_onnx_exec.py
@@ -0,0 +1,149 @@
+import numpy as np
+
+from onnx import TensorProto, helper
+
+import finn.core.execute_custom_node as ex_cu_node
+
+
+def test_execute_custom_node():
+    inputs = np.ndarray(
+        shape=(6, 3, 2, 2),
+        buffer=np.array(
+            [
+                4.8,
+                3.2,
+                1.2,
+                4.9,
+                7.8,
+                2.4,
+                3.1,
+                4.7,
+                6.2,
+                5.1,
+                4.9,
+                2.2,
+                6.2,
+                0.0,
+                0.8,
+                4.7,
+                0.2,
+                5.6,
+                8.9,
+                9.2,
+                9.1,
+                4.0,
+                3.3,
+                4.9,
+                2.3,
+                1.7,
+                1.3,
+                2.2,
+                4.6,
+                3.4,
+                3.7,
+                9.8,
+                4.7,
+                4.9,
+                2.8,
+                2.7,
+                8.3,
+                6.7,
+                4.2,
+                7.1,
+                2.8,
+                3.1,
+                0.8,
+                0.6,
+                4.4,
+                2.7,
+                6.3,
+                6.1,
+                1.4,
+                5.3,
+                2.3,
+                1.9,
+                4.7,
+                8.1,
+                9.3,
+                3.7,
+                2.7,
+                5.1,
+                4.2,
+                1.8,
+                4.1,
+                7.3,
+                7.1,
+                0.4,
+                0.2,
+                1.3,
+                4.3,
+                8.9,
+                1.4,
+                1.6,
+                8.3,
+                9.4,
+            ]
+        ),
+    )
+
+    threshold_values = np.ndarray(
+        shape=(3, 7),
+        buffer=np.array(
+            [
+                0.8,
+                1.4,
+                1.7,
+                3.5,
+                5.2,
+                6.8,
+                8.2,
+                0.2,
+                2.2,
+                3.5,
+                4.5,
+                6.6,
+                8.6,
+                9.2,
+                1.3,
+                4.1,
+                4.5,
+                6.5,
+                7.8,
+                8.1,
+                8.9,
+            ]
+        ),
+    )
+
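+    # the value infos carry the shapes execute_custom_node reads from graph.input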
+    v = helper.make_tensor_value_info('v', TensorProto.FLOAT, [6, 3, 2, 2])
+    thresholds = helper.make_tensor_value_info('thresholds', TensorProto.FLOAT, [3, 7])
+    out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [6, 3, 2, 2])
+
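+    # MultiThreshold node in the custom 'finn' domain, taking data and thresholds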
+    node_def = helper.make_node(
+        'MultiThreshold',
+        ['v', 'thresholds'],
+        ['out'],
+        domain='finn',
+    )
+
+
+    graph_def = helper.make_graph(
+        [node_def],
+        "test_model",
+        [v, thresholds],
+        [out],
+    )
+
+    model = helper.make_model(graph_def, producer_name='onnx-example')
+
+    execution_context = {}
+    execution_context['v'] = inputs
+    execution_context['thresholds'] = threshold_values
+    # reserve an output tensor of the expected shape, mimicking the full execution context
+    execution_context['out'] = np.zeros((6, 3, 2, 2))
+
+    ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+    # first draft check: the output tensor has been filled with the expected shape
+    assert execution_context['out'].shape == (6, 3, 2, 2)
-- 
GitLab