diff --git a/notebooks/FINN-HowToWorkWithONNX.ipynb b/notebooks/FINN-HowToWorkWithONNX.ipynb index 4d866325a3e57bb881b0bf7dedf337f9653a6bc5..8ce25ba19fc83dc7cc5382e6502ba73fef851490 100644 --- a/notebooks/FINN-HowToWorkWithONNX.ipynb +++ b/notebooks/FINN-HowToWorkWithONNX.ipynb @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -310,16 +310,16 @@ "output_type": "stream", "text": [ "The output of the ONNX model is: \n", - "[[ 8. 3. 10. 8.]\n", - " [ 4. 22. 5. 11.]\n", - " [ 7. 15. 2. 8.]\n", - " [ 2. 12. 8. 3.]]\n", + "[[ 7. 7. 4. 3.]\n", + " [ 5. 2. 6. 27.]\n", + " [18. 9. 18. 14.]\n", + " [ 9. 5. 4. 18.]]\n", "\n", "The output of the reference function is: \n", - "[[ 8. 3. 10. 8.]\n", - " [ 4. 22. 5. 11.]\n", - " [ 7. 15. 2. 8.]\n", - " [ 2. 12. 8. 3.]]\n", + "[[ 7. 7. 4. 3.]\n", + " [ 5. 2. 6. 27.]\n", + " [18. 9. 18. 14.]\n", + " [ 9. 5. 4. 18.]]\n", "\n", "The results are the same!\n" ] @@ -374,8 +374,9 @@ " node_index = {}\n", " node_ind = 0\n", " for node in model.graph.node:\n", - " node_index[node_ind] = node.name\n", - " node_ind += 1" + " node_index[node.name] = node_ind\n", + " node_ind += 1\n", + " return node_index" ] }, { @@ -470,7 +471,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -498,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -515,7 +516,7 @@ "for node in add_nodes:\n", " add_pairs = adder_pair(finn_model, node)\n", " if len(add_pairs) != 0:\n", - " substitute_pair = add_pairs\n", + " substitute_pair = add_pairs[0]\n", " print(\"Found following pair that could be replaced by a sum node:\")\n", " for node_pair in add_pairs:\n", " for node in node_pair:\n", @@ -533,9 +534,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The new node gets the following inputs: \n", + "['in1', 'in2', 'in3']\n" + ] + } + ], "source": [ "input_list = []\n", "for i in range(len(substitute_pair)):\n", @@ -547,8 +557,159 @@ " for j in range(len(substitute_pair[i].input)):\n", " if substitute_pair[i].input[j] != substitute_pair[i-1].output[0]:\n", " input_list.append(substitute_pair[i].input[j])\n", - " " + "print(\"The new node gets the following inputs: \\n{}\".format(input_list))" ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sum2\n" + ] + } + ], + "source": [ + "print(substitute_pair[1].output[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "sum_output = substitute_pair[1].output[0]\n", + "Sum_node = onnx.helper.make_node(\n", + " 'Sum',\n", + " inputs=input_list,\n", + " outputs=[sum_output],\n", + " name=\"Sum\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "node_ids = get_node_id(finn_model)\n", + "node_ind = node_ids[substitute_pair[0].name]\n", + "graph.node.insert(node_ind, Sum_node)\n", + "\n", + "for node in substitute_pair:\n", + " graph.node.remove(node)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "onnx_model1 = onnx.helper.make_model(graph, producer_name=\"simple-model1\")\n", + "onnx.save(onnx_model1, 'simple_model1.onnx')" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Stopping http://0.0.0.0:8081\n", + "Serving 'simple_model1.onnx' at http://0.0.0.0:8081\n" + ] + } + ], + "source": [ + "import netron\n", + "netron.start('simple_model1.onnx', port=8081, host=\"0.0.0.0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>\n" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%html\n", + "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "sess = rt.InferenceSession(onnx_model1.SerializeToString())\n", + "output = sess.run(None, input_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The output of the manipulated ONNX model is: \n", + "[[ 7. 7. 4. 3.]\n", + " [ 5. 2. 6. 27.]\n", + " [18. 9. 18. 14.]\n", + " [ 9. 5. 4. 18.]]\n", + "\n", + "The output of the reference function is: \n", + "[[ 7. 7. 4. 3.]\n", + " [ 5. 2. 6. 27.]\n", + " [18. 9. 18. 14.]\n", + " [ 9. 5. 4. 18.]]\n", + "\n", + "The results are the same!\n" + ] + } + ], + "source": [ + "print(\"The output of the manipulated ONNX model is: \\n{}\".format(output[0]))\n", + "print(\"\\nThe output of the reference function is: \\n{}\".format(ref_output))\n", + "\n", + "if (output[0] == ref_output).all():\n", + " print(\"\\nThe results are the same!\")\n", + "else:\n", + " raise Exception(\"Something went wrong, the output of the model doesn't match the expected output!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {