diff --git a/notebooks/FINN-HowToWorkWithONNX.ipynb b/notebooks/FINN-HowToWorkWithONNX.ipynb index 44a030c11c88f4084b7da2b765a6563efcb8dc8d..fa4e381a64ae00be03e476695a3aab0457dbec3f 100644 --- a/notebooks/FINN-HowToWorkWithONNX.ipynb +++ b/notebooks/FINN-HowToWorkWithONNX.ipynb @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -41,6 +41,7 @@ " 'Add',\n", " inputs=['in1', 'in2'],\n", " outputs=['sum1'],\n", + " name='Add1'\n", ")" ] }, @@ -55,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -63,24 +64,28 @@ " 'Add',\n", " inputs=['sum1', 'in3'],\n", " outputs=['sum2'],\n", + " name='Add2',\n", ")\n", "\n", "Add3_node = onnx.helper.make_node(\n", " 'Add',\n", " inputs=['abs1', 'abs1'],\n", " outputs=['sum3'],\n", + " name='Add3',\n", ")\n", "\n", "Abs_node = onnx.helper.make_node(\n", " 'Abs',\n", " inputs=['sum2'],\n", " outputs=['abs1'],\n", + " name='Abs'\n", ")\n", "\n", "Round_node = onnx.helper.make_node(\n", " 'Round',\n", " inputs=['sum3'],\n", " outputs=['out1'],\n", + " name='Round',\n", ")\n" ] }, @@ -93,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -114,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -149,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -166,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 44, "metadata": {}, "outputs": [ { @@ -186,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 45, "metadata": { "scrolled": true }, @@ -220,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -243,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -261,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -280,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -299,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -307,16 +312,16 @@ "output_type": "stream", "text": [ "The output of the ONNX model is: \n", - "[[24. 2. 6. 2.]\n", - " [ 5. 1. 9. 5.]\n", - " [ 6. 1. 8. 8.]\n", - " [14. 4. 0. 1.]]\n", + "[[ 4. 17. 24. 20.]\n", + " [ 4. 7. 13. 21.]\n", + " [ 4. 3. 3. 14.]\n", + " [ 9. 12. 0. 5.]]\n", "\n", "The output of the reference function is: \n", - "[[24. 2. 6. 2.]\n", - " [ 5. 1. 9. 5.]\n", - " [ 6. 1. 8. 8.]\n", - " [14. 4. 0. 1.]]\n", + "[[ 4. 17. 24. 20.]\n", + " [ 4. 7. 13. 21.]\n", + " [ 4. 3. 3. 14.]\n", + " [ 9. 12. 0. 5.]]\n", "\n", "The results are the same!\n" ] @@ -360,12 +365,19 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "from finn.core.modelwrapper import ModelWrapper\n", - "finn_model = ModelWrapper(onnx_model)" + "finn_model = ModelWrapper(onnx_model)\n", + "\n", + "def get_node_id(model):\n", + " node_index = {}\n", + " node_ind = 0\n", + " for node in model.graph.node:\n", + " node_index[node_ind] = node.name\n", + " node_ind += 1" ] }, { @@ -377,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -401,38 +413,23 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 62, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Found adder node: \n", - "input: \"in1\"\n", - "input: \"in2\"\n", - "output: \"sum1\"\n", - "op_type: \"Add\"\n", - "\n", - "Found adder node: \n", - "input: \"sum1\"\n", - "input: \"in3\"\n", - "output: \"sum2\"\n", - "op_type: \"Add\"\n", - "\n", - "Found adder node: \n", - "input: \"abs1\"\n", - "input: \"abs1\"\n", - "output: \"sum3\"\n", - "op_type: \"Add\"\n", - "\n" + "Found adder node: Add1\n", + "Found adder node: Add2\n", + "Found adder node: Add3\n" ] } ], "source": [ "add_nodes = identify_adder_nodes(finn_model)\n", "for node in add_nodes:\n", - " print(\"Found adder node: \\n{}\".format(node))" + " print(\"Found adder node: {}\".format(node.name))" ] }, { @@ -444,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ @@ -475,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -492,34 +489,75 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 68, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[input: \"in1\"\n", - "input: \"in2\"\n", - "output: \"sum1\"\n", - "op_type: \"Add\"\n", - ", input: \"sum1\"\n", - "input: \"in3\"\n", - "output: \"sum2\"\n", - "op_type: \"Add\"\n", - "]\n" + "Found following pair that could be replaced by a sum node:\n", + "Add1\n", + "Add2\n" ] } ], "source": [ - "i = 0\n", "for node in add_nodes:\n", " add_pair = adder_pair(finn_model, node)\n", " if len(add_pair) != 0:\n", - " print(add_pair)\n", - " i+=1" + " substitute_pair = add_pair\n", + " print(\"Found following pair that could be replaced by a sum node:\")\n", + " for node in add_pair:\n", + " print(node.name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<font size=\"3\">Now that the pair to be replaced has been identified, a sum node can be instantiated and inserted into the graph at the correct position. \n", + "\n", + "First of all, the inputs must be determined. For this the adder nodes inputs are used minus the input, which corresponds to the output of the other adder node.</font>" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "input_list = []\n", + "for i in range(len(substitute_pair)):\n", + " if i == 0:\n", + " for j in range(len(substitute_pair[i].input)):\n", + " if substitute_pair[i].input[j] != substitute_pair[i+1].output[0]:\n", + " input_list.append(substitute_pair[i].input[j])\n", + " else:\n", + " for j in range(len(substitute_pair[i].input)):\n", + " if substitute_pair[i].input[j] != substitute_pair[i-1].output[0]:\n", + " input_list.append(substitute_pair[i].input[j])\n", + " " ] }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['in1', 'in2', 'in3']" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, { "cell_type": "code", "execution_count": null,