diff --git a/notebooks/FINN-HowToWorkWithONNX.ipynb b/notebooks/FINN-HowToWorkWithONNX.ipynb index c9ca6a94b0c6645cebe6417d9e3d982449a45ef9..44a030c11c88f4084b7da2b765a6563efcb8dc8d 100644 --- a/notebooks/FINN-HowToWorkWithONNX.ipynb +++ b/notebooks/FINN-HowToWorkWithONNX.ipynb @@ -14,7 +14,8 @@ "metadata": {}, "source": [ "## Outline\n", - "* #### How to create a simple model" + "* #### How to create a simple model\n", + "* #### How to manipulate an ONNX model" ] }, { @@ -54,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -66,7 +67,7 @@ "\n", "Add3_node = onnx.helper.make_node(\n", " 'Add',\n", - " inputs=['sum2', 'abs1'],\n", + " inputs=['abs1', 'abs1'],\n", " outputs=['sum3'],\n", ")\n", "\n", @@ -92,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -148,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -165,13 +166,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "\n", + "Stopping http://0.0.0.0:8081\n", "Serving 'simple_model.onnx' at http://0.0.0.0:8081\n" ] } @@ -183,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 86, "metadata": { "scrolled": true }, @@ -217,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ @@ -227,7 +230,7 @@ " sum1 = np.add(in1, in2)\n", " sum2 = np.add(sum1, in3)\n", " abs1 = np.absolute(sum2)\n", - " sum3 = np.add(abs1, sum2)\n", + " sum3 = np.add(abs1, abs1)\n", " return np.round(sum3)" ] }, @@ -240,7 +243,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -258,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ @@ -277,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -296,20 +299,226 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 96, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The output of the ONNX model is: \n", + "[[24. 2. 6. 2.]\n", + " [ 5. 1. 9. 5.]\n", + " [ 6. 1. 8. 8.]\n", + " [14. 4. 0. 1.]]\n", + "\n", + "The output of the reference function is: \n", + "[[24. 2. 6. 2.]\n", + " [ 5. 1. 9. 5.]\n", + " [ 6. 1. 8. 8.]\n", + " [14. 4. 0. 1.]]\n", + "\n", + "The results are the same!\n" + ] + } + ], "source": [ "ref_output= expected_output(in1_values, in2_values, in3_values)\n", - "assert output[0].all() == ref_output.all()" + "print(\"The output of the ONNX model is: \\n{}\".format(output[0]))\n", + "print(\"\\nThe output of the reference function is: \\n{}\".format(ref_output))\n", + "\n", + "if (output[0] == ref_output).all():\n", + " print(\"\\nThe results are the same!\")\n", + "else:\n", + " raise Exception(\"Something went wrong, the output of the model doesn't match the expected output!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">Now that we have verified that the model works as we expected it to, we can continue working with the graph. </font>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### How to manipulate an ONNX model\n", + "\n", + "<font size=\"3\">In the model there are two successive adder nodes. An adder node in ONNX can only add two inputs, but there is also the [**sum**](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sum) node, which can process more than one input. So it would be a reasonable change of the graph to combine the two successive adder nodes to one sum node. </font>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">In the following we assume that we do not know the appearance of the model, so we first try to identify whether there are two consecutive adders in the graph and then convert them into a sum node. \n", + "\n", + "Here we make use of FINN. FINN provides a thin wrapper around the model which provides several additional helper functions to manipulate the graph. The code can be found [here](https://github.com/Xilinx/finn/blob/dev/src/finn/core/modelwrapper.py) and you can find a more detailed description in the notebook [FINN-ModelWrapper](FINN-ModelWrapper.ipynb).</font>" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 104, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from finn.core.modelwrapper import ModelWrapper\n", + "finn_model = ModelWrapper(onnx_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">Now we design a function that searches for adder nodes in the graph and returns the found nodes. </font>" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [], + "source": [ + "def identify_adder_nodes(model):\n", + " add_nodes = []\n", + " for node in model.graph.node:\n", + " if node.op_type == \"Add\":\n", + " add_nodes.append(node)\n", + " return add_nodes\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">The function iterates over all nodes of the model and if the operation type is \"Add\" the node will be stored in `add_nodes`. At the end `add_nodes` is returned.\n", + "\n", + "If we apply this to our model, three nodes should be returned.</font> " + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found adder node: \n", + "input: \"in1\"\n", + "input: \"in2\"\n", + "output: \"sum1\"\n", + "op_type: \"Add\"\n", + "\n", + "Found adder node: \n", + "input: \"sum1\"\n", + "input: \"in3\"\n", + "output: \"sum2\"\n", + "op_type: \"Add\"\n", + "\n", + "Found adder node: \n", + "input: \"abs1\"\n", + "input: \"abs1\"\n", + "output: \"sum3\"\n", + "op_type: \"Add\"\n", + "\n" + ] + } + ], + "source": [ + "add_nodes = identify_adder_nodes(finn_model)\n", + "for node in add_nodes:\n", + " print(\"Found adder node: \\n{}\".format(node))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">Among other helper functions, ModelWrapper offers two functions to determine the preceding and succeeding node of a node. However, these functions are not getting a node as input, but can determine the consumer or producer of a tensor. We write two functions that uses these helper functions to determine the previous and the next node of a node.</font>" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [], + "source": [ + "def find_predecessor(model, node):\n", + " predecessors = []\n", + " for i in range(len(node.input)):\n", + " producer = model.find_producer(node.input[i])\n", + " predecessors.append(producer)\n", + " return predecessors\n", + " \n", + "\n", + "def find_successor(model, node):\n", + " successors = []\n", + " for i in range(len(node.output)):\n", + " consumer = model.find_consumer(node.output[i])\n", + " successors.append(consumer)\n", + " return successors" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">The first function uses `find_producer` from ModelWrapper to create a list of the producers of the inputs of the given node. So the returned list is indirectly filled with the predecessors of the node. The second function works in a similar way, `find_consumer` from ModelWrapper is used to find the consumers of the output tensors of the node and so a list with the successors can be created. \n", + "\n", + "So now we can find out which adder node has an adder node as successor</font>" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [], + "source": [ + "def adder_pair(model, node):\n", + " successor_list = find_successor(model, node)\n", + " node_pair = []\n", + " for successor in successor_list:\n", + " if successor.op_type == \"Add\":\n", + " node_pair.append(node)\n", + " node_pair.append(successor)\n", + " return node_pair\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[input: \"in1\"\n", + "input: \"in2\"\n", + "output: \"sum1\"\n", + "op_type: \"Add\"\n", + ", input: \"sum1\"\n", + "input: \"in3\"\n", + "output: \"sum2\"\n", + "op_type: \"Add\"\n", + "]\n" + ] + } + ], + "source": [ + "i = 0\n", + "for node in add_nodes:\n", + " add_pair = adder_pair(finn_model, node)\n", + " if len(add_pair) != 0:\n", + " print(add_pair)\n", + " i+=1" + ] }, { "cell_type": "code",