From 245cefe18c15a96c9af02f581dc007b0e7cae72a Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 6 Jul 2020 14:35:23 +0100
Subject: [PATCH] [Transform] rename to RemoveUnusedTensors, also rm
 annotations + test

---
 src/finn/transformation/general.py                 | 14 +++++++++++---
 tests/end2end/test_end2end_cnv_w1a1.py             |  4 ++--
 tests/end2end/test_end2end_cnv_w2a2.py             |  4 ++--
 tests/end2end/test_end2end_tfc_w1a1.py             |  4 ++--
 tests/end2end/test_end2end_tfc_w1a2.py             |  4 ++--
 tests/end2end/test_end2end_tfc_w2a2.py             |  4 ++--
 .../streamline/test_streamline_cnv.py              |  4 ++--
 .../streamline/test_streamline_fc.py               |  5 +++--
 8 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index d9147b700..4303eb17f 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -31,8 +31,11 @@ from finn.transformation import Transformation
 from toposort import toposort_flatten
 
 
-class RemoveUnusedInitAndValueInfo(Transformation):
-    "Remove any unused initializers and value_info in the graph."
+class RemoveUnusedTensors(Transformation):
+    """Remove any unused tensors in the graph by removing any initializers,
+    ValueInfo and tensor annotations associated with them. Unused tensors are
+    those that do not appear as an input or output of any graph node.
+    """
 
     def apply(self, model):
         graph_modified = False
@@ -44,7 +47,8 @@ class RemoveUnusedInitAndValueInfo(Transformation):
                 used_tensors.add(i)
             for o in node.output:
                 used_tensors.add(o)
-        # remove initializers and value_info not in the used set
+        # remove initializers, value_info and annotations that are not in the
+        # used set of tensors, as determined by the graph node i/o
         for init in onnx_graph.initializer:
             if init.name not in used_tensors:
                 onnx_graph.initializer.remove(init)
@@ -53,6 +57,10 @@ class RemoveUnusedInitAndValueInfo(Transformation):
             if vi.name not in used_tensors:
                 onnx_graph.value_info.remove(vi)
                 graph_modified = True
+        for qa in onnx_graph.quantization_annotation:
+            if qa.tensor_name not in used_tensors:
+                onnx_graph.quantization_annotation.remove(qa)
+                graph_modified = True
 
         return (model, graph_modified)
 
diff --git a/tests/end2end/test_end2end_cnv_w1a1.py b/tests/end2end/test_end2end_cnv_w1a1.py
index c54079785..a2cfcd3a8 100644
--- a/tests/end2end/test_end2end_cnv_w1a1.py
+++ b/tests/end2end/test_end2end_cnv_w1a1.py
@@ -43,7 +43,7 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -114,7 +114,7 @@ def test_end2end_cnv_w1a1_streamline():
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(ConvertBipolarMatMulToXnorPopcount())
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_cnv_w1a1_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_cnv_w2a2.py b/tests/end2end/test_end2end_cnv_w2a2.py
index ae1771fab..f45b0a3ec 100644
--- a/tests/end2end/test_end2end_cnv_w2a2.py
+++ b/tests/end2end/test_end2end_cnv_w2a2.py
@@ -43,7 +43,7 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -112,7 +112,7 @@ def test_end2end_cnv_w2a2_streamline():
     model = model.transform(MakeMaxPoolNHWC())
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_cnv_w2a2_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a1.py b/tests/end2end/test_end2end_tfc_w1a1.py
index 8dafc7fec..31659df63 100644
--- a/tests/end2end/test_end2end_tfc_w1a1.py
+++ b/tests/end2end/test_end2end_tfc_w1a1.py
@@ -64,7 +64,7 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -110,7 +110,7 @@ def test_end2end_tfc_w1a1_import_and_tidy():
 def test_end2end_tfc_w1a1_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_tfc_w1a1_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a2.py b/tests/end2end/test_end2end_tfc_w1a2.py
index cba558d83..d5579f625 100644
--- a/tests/end2end/test_end2end_tfc_w1a2.py
+++ b/tests/end2end/test_end2end_tfc_w1a2.py
@@ -62,7 +62,7 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -105,7 +105,7 @@ def test_end2end_tfc_w1a2_import_and_tidy():
 def test_end2end_tfc_w1a2_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_tfc_w1a2_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w2a2.py b/tests/end2end/test_end2end_tfc_w2a2.py
index 44a8a8b17..470119f34 100644
--- a/tests/end2end/test_end2end_tfc_w2a2.py
+++ b/tests/end2end/test_end2end_tfc_w2a2.py
@@ -62,7 +62,7 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -105,7 +105,7 @@ def test_end2end_tfc_w2a2_import_and_tidy():
 def test_end2end_tfc_w2a2_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_tfc_w2a2_streamlined.onnx")
 
 
diff --git a/tests/transformation/streamline/test_streamline_cnv.py b/tests/transformation/streamline/test_streamline_cnv.py
index 050e12a97..bcb66a2c2 100644
--- a/tests/transformation/streamline/test_streamline_cnv.py
+++ b/tests/transformation/streamline/test_streamline_cnv.py
@@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -79,7 +79,7 @@ def test_streamline_cnv(size, wbits, abits):
     expected = expected_ctx[model.graph.output[0].name]
     # model.save("orig_cnv.onnx")
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     assert len(model.graph.initializer) == 21
     assert len(model.graph.value_info) == 43
     # model.save("streamlined_cnv.onnx")
diff --git a/tests/transformation/streamline/test_streamline_fc.py b/tests/transformation/streamline/test_streamline_fc.py
index cb356dc38..dd7e756b4 100644
--- a/tests/transformation/streamline/test_streamline_fc.py
+++ b/tests/transformation/streamline/test_streamline_fc.py
@@ -38,7 +38,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import (
-    RemoveUnusedInitAndValueInfo,
+    RemoveUnusedTensors,
     RemoveStaticGraphInputs,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -79,9 +79,10 @@ def test_streamline_fc(size, wbits, abits):
     expected_ctx = oxe.execute_onnx(model, input_dict, True)
     expected = expected_ctx[model.graph.output[0].name]
     model = model.transform(Streamline())
-    model = model.transform(RemoveUnusedInitAndValueInfo())
+    model = model.transform(RemoveUnusedTensors())
     assert len(model.graph.initializer) == 11
     assert len(model.graph.value_info) == 21
+    assert len(model.graph.quantization_annotation) == 18
     produced_ctx = oxe.execute_onnx(model, input_dict, True)
     produced = produced_ctx[model.graph.output[0].name]
     assert np.isclose(expected, produced, atol=1e-3).all()
-- 
GitLab