diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py index 488391740fc25f1f7caa657adc9ed55bdc2f9722..9a238ecf1b017e183284e4128164bb02c0dc3549 100644 --- a/src/finn/transformation/general.py +++ b/src/finn/transformation/general.py @@ -31,6 +31,32 @@ from finn.transformation import Transformation from toposort import toposort_flatten +class RemoveUnusedInitAndValueInfo(Transformation): + "Remove any unused initializers and value_info in the graph." + + def apply(self, model): + graph_modified = False + onnx_graph = model.model.graph + # build a set of tensors that we actually use in the graph nodes + used_tensors = set() + for node in model.graph.node: + for i in node.input: + used_tensors.add(i) + for o in node.output: + used_tensors.add(o) + # remove initializers and value_info not in the used set + for init in onnx_graph.initializer: + if init.name not in used_tensors: + onnx_graph.initializer.remove(init) + graph_modified = True + for vi in onnx_graph.value_info: + if vi.name not in used_tensors: + onnx_graph.value_info.remove(vi) + graph_modified = True + + return (model, graph_modified) + + class GiveUniqueNodeNames(Transformation): """Give unique names to each node in the graph using enumeration.""" @@ -121,24 +147,23 @@ class GiveUniqueParameterTensors(Transformation): class SortGraph(Transformation): """ Returns the model with its node list sorted topologically. - Any ONNX graph to be executed must have a topologically sorted node list, as dictated - by the ONNX standard. + Any ONNX graph to be executed must have a topologically sorted node list, + as dictated by the ONNX standard. """ - + # Notes on SortGraph performance: - # benchmark in tests/transformation/test_sort_graph.py - # - # The algorithm doesn't move initializers so its performance should only depend on - # the number of nodes - # - # Relative order of magnitudes for time per step: - # - Gather graph structure: base - # - Sort nodes: 0.1 of base - # - Remove and insert in order : 0.001 of base - # - # Notes: - # Remove nodes and insert them in order: - # Probably this is faster than copying initializers and more robust in general + # benchmark in tests/transformation/test_sort_graph.py + # The algorithm doesn't move initializers so its performance should only depend on + # the number of nodes + # + # Relative order of magnitudes for time per step: + # - Gather graph structure: base + # - Sort nodes: 0.1 of base + # - Remove and insert in order : 0.001 of base + # + # Notes: + # Remove nodes and insert them in order: + # Probably this is faster than copying initializers and more robust in general def apply(self, model): # Gather graph structure