diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5e50368bcfd02394c2b57b0a5a119b9e94e9ea --- /dev/null +++ b/src/finn/transformation/streamline/remove.py @@ -0,0 +1,40 @@ +from finn.transformation import Transformation +from finn.transformation.infer_shapes import InferShapes + + +class RemoveIdentityOps(Transformation): + """Remove identity ops like Add/Sub with zero or Mul/Div with one""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if ( + n.op_type in ["Add", "Sub"] + and not model.is_fork_node(n) + and not model.is_join_node(n) + ): + A = model.get_initializer(n.input[1]) + if A == 0: + producer = model.find_producer(n.input[0]) + # remove node and wire output tensor to + # output of producer node + producer.output[0] = n.output[0] + graph.node.remove(n) + + elif ( + n.op_type in ["Mul", "Div"] + and not model.is_fork_node(n) + and not model.is_join_node(n) + ): + A = model.get_initializer(n.input[1]) + if A == 1: + producer = model.find_producer(n.input[0]) + # remove node and wire output tensor to + # output of producer node + producer.output[0] = n.output[0] + graph.node.remove(n) + model = model.transform(InferShapes()) + return (model, graph_modified)