From a8bbc11f91ba7235d9a3c15b54d85761f5b321ac Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 1 Jul 2020 13:24:28 +0100 Subject: [PATCH] Rename MoveScalarAddPastConv to MoveAddPastConv --- notebooks/end2end_example/tfc_end2end_example.ipynb | 12 +++++++----- src/finn/transformation/streamline/__init__.py | 4 ++-- src/finn/transformation/streamline/reorder.py | 2 +- tests/transformation/test_move_scalar_past_conv.py | 6 +++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/notebooks/end2end_example/tfc_end2end_example.ipynb b/notebooks/end2end_example/tfc_end2end_example.ipynb index d57306148..c84efc964 100644 --- a/notebooks/end2end_example/tfc_end2end_example.ipynb +++ b/notebooks/end2end_example/tfc_end2end_example.ipynb @@ -132,7 +132,7 @@ " " ], "text/plain": [ - "<IPython.lib.display.IFrame at 0x7f8890385828>" + "<IPython.lib.display.IFrame at 0x7f7cc4290940>" ] }, "execution_count": 3, @@ -293,7 +293,7 @@ " " ], "text/plain": [ - "<IPython.lib.display.IFrame at 0x7fe1ad0639e8>" + "<IPython.lib.display.IFrame at 0x7f7c6c567f28>" ] }, "execution_count": 6, @@ -333,9 +333,10 @@ " ConvertDivToMul(),\n", " BatchNormToAffine(),\n", " ConvertSignToThres(),\n", + " AbsorbSignBiasIntoMultiThreshold(),\n", " MoveAddPastMul(),\n", " MoveScalarAddPastMatMul(),\n", - " MoveScalarAddPastConv(),\n", + " MoveAddPastConv(),\n", " MoveScalarMulPastMatMul(),\n", " MoveScalarMulPastConv(),\n", " MoveAddPastMul(),\n", @@ -350,6 +351,7 @@ " ]\n", " for trn in streamline_transformations:\n", " model = model.transform(trn)\n", + " model = model.transform(RemoveIdentityOps())\n", " model = model.transform(GiveUniqueNodeNames())\n", " model = model.transform(GiveReadableTensorNames())\n", " model = model.transform(InferDataTypes())\n", @@ -400,7 +402,7 @@ " " ], "text/plain": [ - "<IPython.lib.display.IFrame at 0x7fe1346e4ef0>" + "<IPython.lib.display.IFrame at 0x7f7c6c0bf898>" ] }, "execution_count": 8, @@ -454,7 +456,7 @@ " " ], "text/plain": [ - "<IPython.lib.display.IFrame at 0x7fe1346f7780>" + "<IPython.lib.display.IFrame at 0x7f7c6c0e5c18>" ] }, "execution_count": 9, diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py index d9c12a209..d7686eaad 100644 --- a/src/finn/transformation/streamline/__init__.py +++ b/src/finn/transformation/streamline/__init__.py @@ -53,7 +53,7 @@ from finn.transformation.streamline.reorder import ( MoveAddPastMul, MoveScalarMulPastMatMul, MoveScalarAddPastMatMul, - MoveScalarAddPastConv, + MoveAddPastConv, MoveScalarMulPastConv, ) @@ -75,7 +75,7 @@ class Streamline(Transformation): AbsorbSignBiasIntoMultiThreshold(), MoveAddPastMul(), MoveScalarAddPastMatMul(), - MoveScalarAddPastConv(), + MoveAddPastConv(), MoveScalarMulPastMatMul(), MoveScalarMulPastConv(), MoveAddPastMul(), diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index a1bd16f6d..a12c31c96 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -216,7 +216,7 @@ class MoveScalarAddPastMatMul(Transformation): return (model, graph_modified) -class MoveScalarAddPastConv(Transformation): +class MoveAddPastConv(Transformation): """Move scalar add operations past conv operations. We want to have adds next to each other such that they can be collapsed into a single add.""" diff --git a/tests/transformation/test_move_scalar_past_conv.py b/tests/transformation/test_move_scalar_past_conv.py index 0f50642d2..96afce4bd 100644 --- a/tests/transformation/test_move_scalar_past_conv.py +++ b/tests/transformation/test_move_scalar_past_conv.py @@ -7,14 +7,14 @@ import finn.core.onnx_exec as ox from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_shapes import InferShapes from finn.transformation.streamline import ( - MoveScalarAddPastConv, + MoveAddPastConv, MoveScalarMulPastConv, ) @pytest.mark.parametrize("padding", [False, True]) @pytest.mark.parametrize( - "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())], + "test_args", [("Add", MoveAddPastConv()), ("Mul", MoveScalarMulPastConv())], ) def test_move_scalar_past_conv(test_args, padding): scalar_op = test_args[0] @@ -92,7 +92,7 @@ def test_move_scalar_past_conv(test_args, padding): @pytest.mark.parametrize( - "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())], + "test_args", [("Add", MoveAddPastConv()), ("Mul", MoveScalarMulPastConv())], ) def test_move_scalar_past_conv_only_if_linear(test_args): scalar_op = test_args[0] -- GitLab