From 15e8facada13bc747477d9a521e628fbe2a5ff67 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 6 Jul 2020 11:25:39 +0100
Subject: [PATCH] [Transform] introduce RemoveUnusedInitAndValueInfo

---
 src/finn/transformation/general.py | 57 +++++++++++++++++++++---------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 488391740..9a238ecf1 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -31,6 +31,32 @@ from finn.transformation import Transformation
 from toposort import toposort_flatten
 
 
+class RemoveUnusedInitAndValueInfo(Transformation):
+    "Remove any unused initializers and value_info in the graph."
+
+    def apply(self, model):
+        graph_modified = False
+        onnx_graph = model.model.graph
+        # build a set of tensors that we actually use in the graph nodes
+        used_tensors = set()
+        for node in model.graph.node:
+            for i in node.input:
+                used_tensors.add(i)
+            for o in node.output:
+                used_tensors.add(o)
+        # remove initializers and value_info not in the used set
+        for init in onnx_graph.initializer:
+            if init.name not in used_tensors:
+                onnx_graph.initializer.remove(init)
+                graph_modified = True
+        for vi in onnx_graph.value_info:
+            if vi.name not in used_tensors:
+                onnx_graph.value_info.remove(vi)
+                graph_modified = True
+
+        return (model, graph_modified)
+
+
 class GiveUniqueNodeNames(Transformation):
     """Give unique names to each node in the graph using enumeration."""
 
@@ -121,24 +147,23 @@ class GiveUniqueParameterTensors(Transformation):
 
 class SortGraph(Transformation):
     """ Returns the model with its node list sorted topologically.
-    Any ONNX graph to be executed must have a topologically sorted node list, as dictated
-    by the ONNX standard.
+    Any ONNX graph to be executed must have a topologically sorted node list,
+    as dictated by the ONNX standard.
     """
-    
+
     # Notes on SortGraph performance:
-    #         benchmark in  tests/transformation/test_sort_graph.py
-    # 
-    #         The algorithm doesn't move initializers so its performance should only depend on
-    #         the number of nodes
-    # 
-    #         Relative order of magnitudes for time per step:
-    #             - Gather graph structure:       base
-    #             - Sort nodes:                   0.1 of base
-    #             - Remove and insert in order :  0.001 of base
-    # 
-    #     Notes:
-    #         Remove nodes and insert them in order:
-    #           Probably this is faster than copying initializers and more robust in general
+    # benchmark in  tests/transformation/test_sort_graph.py
+    # The algorithm doesn't move initializers so its performance should only depend on
+    # the number of nodes
+    #
+    # Relative order of magnitudes for time per step:
+    # - Gather graph structure:       base
+    # - Sort nodes:                   0.1 of base
+    # - Remove and insert in order :  0.001 of base
+    #
+    # Notes:
+    # Remove nodes and insert them in order:
+    # Probably this is faster than copying initializers and more robust in general
 
     def apply(self, model):
         # Gather graph structure
-- 
GitLab