diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py index 9a238ecf1b017e183284e4128164bb02c0dc3549..d9147b700cd212225628e8b5dfdcd2dd53c543e0 100644 --- a/src/finn/transformation/general.py +++ b/src/finn/transformation/general.py @@ -57,6 +57,21 @@ class RemoveUnusedInitAndValueInfo(Transformation): return (model, graph_modified) +class RemoveStaticGraphInputs(Transformation): + "Remove any top-level graph inputs that have initializers." + + def apply(self, model): + graph_modified = False + for i in model.graph.input: + if model.get_initializer(i.name) is not None: + # move ValueInfo to internal (value_info) container + model.graph.value_info.append(i) + model.graph.input.remove(i) + graph_modified = True + + return (model, graph_modified) + + class GiveUniqueNodeNames(Transformation): """Give unique names to each node in the graph using enumeration."""