From 313d4a4fb1a31cdb129dc758e70f7c983babdff8 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 18 Dec 2019 10:17:51 +0000
Subject: [PATCH] [notebook] Added code for section about graph manipulation

---
 notebooks/FINN-HowToWorkWithONNX.ipynb | 197 ++++++++++++++++++++++---
 1 file changed, 179 insertions(+), 18 deletions(-)

diff --git a/notebooks/FINN-HowToWorkWithONNX.ipynb b/notebooks/FINN-HowToWorkWithONNX.ipynb
index 4d866325a..8ce25ba19 100644
--- a/notebooks/FINN-HowToWorkWithONNX.ipynb
+++ b/notebooks/FINN-HowToWorkWithONNX.ipynb
@@ -31,7 +31,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -310,16 +310,16 @@
      "output_type": "stream",
      "text": [
       "The output of the ONNX model is: \n",
-      "[[ 8.  3. 10.  8.]\n",
-      " [ 4. 22.  5. 11.]\n",
-      " [ 7. 15.  2.  8.]\n",
-      " [ 2. 12.  8.  3.]]\n",
+      "[[ 7.  7.  4.  3.]\n",
+      " [ 5.  2.  6. 27.]\n",
+      " [18.  9. 18. 14.]\n",
+      " [ 9.  5.  4. 18.]]\n",
       "\n",
       "The output of the reference function is: \n",
-      "[[ 8.  3. 10.  8.]\n",
-      " [ 4. 22.  5. 11.]\n",
-      " [ 7. 15.  2.  8.]\n",
-      " [ 2. 12.  8.  3.]]\n",
+      "[[ 7.  7.  4.  3.]\n",
+      " [ 5.  2.  6. 27.]\n",
+      " [18.  9. 18. 14.]\n",
+      " [ 9.  5.  4. 18.]]\n",
       "\n",
       "The results are the same!\n"
      ]
@@ -374,8 +374,9 @@
     "    node_index = {}\n",
     "    node_ind = 0\n",
     "    for node in model.graph.node:\n",
-    "        node_index[node_ind] = node.name\n",
-    "        node_ind += 1"
+    "        node_index[node.name] = node_ind\n",
+    "        node_ind += 1\n",
+    "    return node_index"
    ]
   },
   {
@@ -470,7 +471,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -498,7 +499,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [
     {
@@ -515,7 +516,7 @@
     "for node in add_nodes:\n",
     "    add_pairs = adder_pair(finn_model, node)\n",
     "    if len(add_pairs) != 0:\n",
-    "        substitute_pair = add_pairs\n",
+    "        substitute_pair = add_pairs[0]\n",
     "        print(\"Found following pair that could be replaced by a sum node:\")\n",
     "        for node_pair in add_pairs:\n",
     "            for node in node_pair:\n",
@@ -533,9 +534,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 21,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The new node gets the following inputs: \n",
+      "['in1', 'in2', 'in3']\n"
+     ]
+    }
+   ],
    "source": [
     "input_list = []\n",
     "for i in range(len(substitute_pair)):\n",
@@ -547,8 +557,159 @@
     "        for j in range(len(substitute_pair[i].input)):\n",
     "            if substitute_pair[i].input[j] != substitute_pair[i-1].output[0]:\n",
     "                input_list.append(substitute_pair[i].input[j])\n",
-    "        "
+    "print(\"The new node gets the following inputs: \\n{}\".format(input_list))"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "sum2\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(substitute_pair[1].output[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sum_output = substitute_pair[1].output[0]\n",
+    "Sum_node = onnx.helper.make_node(\n",
+    "    'Sum',\n",
+    "    inputs=input_list,\n",
+    "    outputs=[sum_output],\n",
+    "    name=\"Sum\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "node_ids = get_node_id(finn_model)\n",
+    "node_ind = node_ids[substitute_pair[0].name]\n",
+    "graph.node.insert(node_ind, Sum_node)\n",
+    "\n",
+    "for node in substitute_pair:\n",
+    "    graph.node.remove(node)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "onnx_model1 = onnx.helper.make_model(graph, producer_name=\"simple-model1\")\n",
+    "onnx.save(onnx_model1, 'simple_model1.onnx')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Stopping http://0.0.0.0:8081\n",
+      "Serving 'simple_model1.onnx' at http://0.0.0.0:8081\n"
+     ]
+    }
+   ],
+   "source": [
+    "import netron\n",
+    "netron.start('simple_model1.onnx', port=8081, host=\"0.0.0.0\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>\n"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%html\n",
+    "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sess = rt.InferenceSession(onnx_model1.SerializeToString())\n",
+    "output = sess.run(None, input_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The output of the manipulated ONNX model is: \n",
+      "[[ 7.  7.  4.  3.]\n",
+      " [ 5.  2.  6. 27.]\n",
+      " [18.  9. 18. 14.]\n",
+      " [ 9.  5.  4. 18.]]\n",
+      "\n",
+      "The output of the reference function is: \n",
+      "[[ 7.  7.  4.  3.]\n",
+      " [ 5.  2.  6. 27.]\n",
+      " [18.  9. 18. 14.]\n",
+      " [ 9.  5.  4. 18.]]\n",
+      "\n",
+      "The results are the same!\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"The output of the manipulated ONNX model is: \\n{}\".format(output[0]))\n",
+    "print(\"\\nThe output of the reference function is: \\n{}\".format(ref_output))\n",
+    "\n",
+    "if (output[0] == ref_output).all():\n",
+    "    print(\"\\nThe results are the same!\")\n",
+    "else:\n",
+    "    raise Exception(\"Something went wrong, the output of the model doesn't match the expected output!\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
-- 
GitLab