diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py index 12c6984c6e66e1917d2a1e0a74c8620ccb6afabc..0a36b8bbe5c05a8226ae647e0061c1551f3b1cbf 100644 --- a/src/finn/transformation/streamline/remove.py +++ b/src/finn/transformation/streamline/remove.py @@ -32,6 +32,23 @@ from finn.transformation.infer_shapes import InferShapes import numpy as np +def _remove_node_and_rewire(model, node): + producer = model.find_producer(node.input[0]) + if producer is not None: + # wire output tensor to + # output of producer node + producer.output[0] = node.output[0] + else: + # node is first in graph + consumer = model.find_consumer(node.output[0]) + assert consumer is not None, "Whole graph is identity" + assert consumer.input[0] == node.output[0] + # rewire consumer's input directly to graph input + consumer.input[0] = node.input[0] + # remove node + model.graph.node.remove(node) + + class RemoveIdentityOps(Transformation): """Remove identity ops like Add/Sub with zero or Mul/Div with one""" @@ -48,11 +65,7 @@ class RemoveIdentityOps(Transformation): ): A = model.get_initializer(n.input[1]) if A is not None and (A == np.zeros_like(A)).all(): - producer = model.find_producer(n.input[0]) - # remove node and wire output tensor to - # output of producer node - producer.output[0] = n.output[0] - graph.node.remove(n) + _remove_node_and_rewire(model, n) elif ( n.op_type in ["Mul", "Div"] @@ -61,10 +74,6 @@ class RemoveIdentityOps(Transformation): ): A = model.get_initializer(n.input[1]) if A is not None and (A == np.ones_like(A)).all(): - producer = model.find_producer(n.input[0]) - # remove node and wire output tensor to - # output of producer node - producer.output[0] = n.output[0] - graph.node.remove(n) + _remove_node_and_rewire(model, n) model = model.transform(InferShapes()) return (model, graph_modified) diff --git a/tests/transformation/streamline/test_remove_identity_ops.py b/tests/transformation/streamline/test_remove_identity_ops.py index 536c1ab0b48fa44388da23f45b528da3c5f3b2f2..98430fad0e0f4c17d77ddbf44afeeccd44372047 100644 --- a/tests/transformation/streamline/test_remove_identity_ops.py +++ b/tests/transformation/streamline/test_remove_identity_ops.py @@ -11,7 +11,7 @@ from finn.transformation.streamline.remove import RemoveIdentityOps from finn.util.basic import gen_finn_dt_tensor -def insert_identity_op(model, op): +def insert_identity_op(model, op, as_first_node): if op in ["Add", "Sub"]: val = np.asarray([0.0], dtype=np.float32) elif op in ["Mul", "Div"]: @@ -19,10 +19,15 @@ def insert_identity_op(model, op): else: return - identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"]) graph = model.graph - graph.node.insert(3, identity_node) - graph.node[-1].input[0] = "ident_out" + if as_first_node: + identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"]) + graph.node.insert(0, identity_node) + graph.node[1].input[0] = "ident_out" + else: + identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"]) + graph.node.insert(3, identity_node) + graph.node[-1].input[0] = "ident_out" model.set_initializer("value", val) return model @@ -30,7 +35,8 @@ def insert_identity_op(model, op): # identity operations to be inserted @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"]) -def test_remove_identity_ops(op): +@pytest.mark.parametrize("as_first_node", [False, True]) +def test_remove_identity_ops(op, as_first_node): # set up onnx model inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1]) @@ -64,7 +70,7 @@ def test_remove_identity_ops(op): model.set_initializer("shape", shape_values) model.set_initializer("div", div_values) model.set_initializer("matmul", matmul_values) - insert_identity_op(model, op) + insert_identity_op(model, op, as_first_node) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) idict = {"inp": inp_values}