From 5ce16895771b383e269b74c73b070424b31370fd Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 30 Oct 2019 14:59:44 +0000 Subject: [PATCH] [Integration] Interface of execute_custom_node.py to onnx_exec.py finished and first draft of unit test added --- src/finn/core/execute_custom_node.py | 42 +++++++- tests/test_custom_onnx_exec.py | 147 +++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 4 deletions(-) create mode 100644 tests/test_custom_onnx_exec.py diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py index 064ed5f22..678794a53 100644 --- a/src/finn/core/execute_custom_node.py +++ b/src/finn/core/execute_custom_node.py @@ -1,8 +1,42 @@ #import onnx.helper as helper -#import finn.core.MultiThreshold +import finn.core.MultiThreshold as multiThresh -def execute_custom_node(node, context, graph) +def execute_custom_node(node, context, graph) : """Call custom implementation to execute a single custom node. Input/output provided via context.""" - node_inputs = list(filter(lambda x: x.name in node.input, graph.input)) - print(node_inputs) + + if node.op_type == 'MultiThreshold' : + node_inputs = list(filter(lambda x: x.name in node.input, graph.input)) + + # extract shape size of input tensors to determine which is input and which thresholds + shape_dict = {} + for inputs in node_inputs : + shape_dict[inputs.name]=0 + for dim_value in inputs.type.tensor_type.shape.dim : + shape_dict[inputs.name] += 1 + + # store input values in right tensors according to the shape size + for inputs in node_inputs : + if shape_dict[inputs.name] == 4 : + v = context[inputs.name] + else : + thresholds = context[inputs.name] + + output_list = multiThresh.execute(v, thresholds) + + for output_ind in node.output: + print(output_ind) + #outp = node.output[output_ind] + #if output_list[output_ind].shape != context[outp].shape: + # raise Exception( + # "Output shapes disagree after node execution: found %s vs expected %s" + # % (str(output_list[output_ind].shape.shape), str(context[outp].shape)) + # ) + #context[outp] = output_list[output_ind] + + + else : + raise Exception( + "This custom node is currently not supported." + ) + diff --git a/tests/test_custom_onnx_exec.py b/tests/test_custom_onnx_exec.py new file mode 100644 index 000000000..828e10d1d --- /dev/null +++ b/tests/test_custom_onnx_exec.py @@ -0,0 +1,147 @@ +import numpy as np + +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto + +import finn.core.execute_custom_node as ex_cu_node + + +def test_execute_custom_node() : + inputs = np.ndarray( + shape=(6, 3, 2, 2), + buffer=np.array( + [ + 4.8, + 3.2, + 1.2, + 4.9, + 7.8, + 2.4, + 3.1, + 4.7, + 6.2, + 5.1, + 4.9, + 2.2, + 6.2, + 0.0, + 0.8, + 4.7, + 0.2, + 5.6, + 8.9, + 9.2, + 9.1, + 4.0, + 3.3, + 4.9, + 2.3, + 1.7, + 1.3, + 2.2, + 4.6, + 3.4, + 3.7, + 9.8, + 4.7, + 4.9, + 2.8, + 2.7, + 8.3, + 6.7, + 4.2, + 7.1, + 2.8, + 3.1, + 0.8, + 0.6, + 4.4, + 2.7, + 6.3, + 6.1, + 1.4, + 5.3, + 2.3, + 1.9, + 4.7, + 8.1, + 9.3, + 3.7, + 2.7, + 5.1, + 4.2, + 1.8, + 4.1, + 7.3, + 7.1, + 0.4, + 0.2, + 1.3, + 4.3, + 8.9, + 1.4, + 1.6, + 8.3, + 9.4, + ] + ), + ) + + threshold_values = np.ndarray( + shape=(3, 7), + buffer=np.array( + [ + 0.8, + 1.4, + 1.7, + 3.5, + 5.2, + 6.8, + 8.2, + 0.2, + 2.2, + 3.5, + 4.5, + 6.6, + 8.6, + 9.2, + 1.3, + 4.1, + 4.5, + 6.5, + 7.8, + 8.1, + 8.9, + ] + ), + ) + + v = helper.make_tensor_value_info('v', TensorProto.FLOAT, [6, 3, 2, 2]) + thresholds = helper.make_tensor_value_info('thresholds', TensorProto.FLOAT, [3, 7]) + out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [6, 3, 2, 2]) + + node_def = helper.make_node( + 'MultiThreshold', + ['v', 'thresholds'], + ['out'], + domain='finn' + ) + + + graph_def = helper.make_graph( + [node_def], + "test_model", + [v, thresholds], + [out] + ) + + model = helper.make_model(graph_def, producer_name='onnx-example') + + execution_context = {} + execution_context['v'] = inputs + execution_context['thresholds'] = threshold_values + + print(ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)) + + -- GitLab