diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py
index 0a36b8bbe5c05a8226ae647e0061c1551f3b1cbf..0abcf441f9636a52f9194df325874d293530f78a 100644
--- a/src/finn/transformation/streamline/remove.py
+++ b/src/finn/transformation/streamline/remove.py
@@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node):
 
 
 class RemoveIdentityOps(Transformation):
-    """Remove identity ops like Add/Sub with zero or Mul/Div with one"""
+    """Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance
+    value (defaults to 1e-05) can be specified during init for the comparison
+    to zero/one."""
+
+    def __init__(self, atol=1e-05):
+        super().__init__()
+        self.atol = atol
 
     def apply(self, model):
         graph = model.graph
@@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation):
                 and not model.is_join_node(n)
             ):
                 A = model.get_initializer(n.input[1])
-                if A is not None and (A == np.zeros_like(A)).all():
+                if (
+                    A is not None
+                    and np.isclose(A, np.zeros_like(A), atol=self.atol).all()
+                ):
                     _remove_node_and_rewire(model, n)
 
             elif (
@@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation):
                 and not model.is_join_node(n)
             ):
                 A = model.get_initializer(n.input[1])
-                if A is not None and (A == np.ones_like(A)).all():
+                if (
+                    A is not None
+                    and np.isclose(A, np.ones_like(A), atol=self.atol).all()
+                ):
                     _remove_node_and_rewire(model, n)
         model = model.transform(InferShapes())
         return (model, graph_modified)
diff --git a/tests/end2end/test_end2end_mobilenet_v1.py b/tests/end2end/test_end2end_mobilenet_v1.py
index 79263a7099b91fb0dbaa10871f7859690ab9e4c2..5bfe8e1ea1b48ed77c40a584d624cc0ecdedb668 100644
--- a/tests/end2end/test_end2end_mobilenet_v1.py
+++ b/tests/end2end/test_end2end_mobilenet_v1.py
@@ -188,6 +188,10 @@ def test_end2end_mobilenet_streamline():
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
     model.save(build_dir + "/end2end_mobilenet_streamlined.onnx")
+    assert (
+        len(model.get_nodes_by_op_type("Add")) == 1
+    )  # only final quantized bias Add op remains
+    assert len(model.get_nodes_by_op_type("Mul")) == 0  # no Mul ops remain
 
 
 def test_end2end_mobilenet_lowering():
diff --git a/tests/transformation/streamline/test_remove_identity_ops.py b/tests/transformation/streamline/test_remove_identity_ops.py
index 98430fad0e0f4c17d77ddbf44afeeccd44372047..d02e1d39755bf4783cd5dbdc2b88ca0931e02874 100644
--- a/tests/transformation/streamline/test_remove_identity_ops.py
+++ b/tests/transformation/streamline/test_remove_identity_ops.py
@@ -11,11 +11,17 @@ from finn.transformation.streamline.remove import RemoveIdentityOps
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def insert_identity_op(model, op, as_first_node):
+def insert_identity_op(model, op, as_first_node, approx):
+    if approx:
+        zero_val = 0.000001
+        one_val = 0.999999
+    else:
+        zero_val = 0.0
+        one_val = 1.0
     if op in ["Add", "Sub"]:
-        val = np.asarray([0.0], dtype=np.float32)
+        val = np.asarray([zero_val], dtype=np.float32)
     elif op in ["Mul", "Div"]:
-        val = np.asarray([1.0], dtype=np.float32)
+        val = np.asarray([one_val], dtype=np.float32)
     else:
         return
 
@@ -35,8 +41,9 @@ def insert_identity_op(model, op, as_first_node):
 
 # identity operations to be inserted
 @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
+@pytest.mark.parametrize("approx", [False, True])
 @pytest.mark.parametrize("as_first_node", [False, True])
-def test_remove_identity_ops(op, as_first_node):
+def test_remove_identity_ops(op, as_first_node, approx):
     # set up onnx model
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
@@ -70,7 +77,7 @@ def test_remove_identity_ops(op, as_first_node):
     model.set_initializer("shape", shape_values)
     model.set_initializer("div", div_values)
     model.set_initializer("matmul", matmul_values)
-    insert_identity_op(model, op, as_first_node)
+    insert_identity_op(model, op, as_first_node, approx)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     idict = {"inp": inp_values}
@@ -84,4 +91,4 @@ def test_remove_identity_ops(op, as_first_node):
 
     odict = oxe.execute_onnx(model, idict)
     out_after = odict["outp"]
-    assert (out_before == out_after).all()
+    assert np.isclose(out_before, out_after, atol=1e-3).all()
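Usage note (reviewer sketch, not part of the patch): the snippet below illustrates the new atol knob on a toy graph. The graph name "near_identity", the "scale"/"eps" initializer names, the 1e-6 offset and the shapes are made up for illustration; RemoveIdentityOps, ModelWrapper, InferShapes and get_nodes_by_op_type are the existing FINN APIs exercised by this diff.

import numpy as np
from onnx import TensorProto, helper

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline.remove import RemoveIdentityOps

# Toy graph: outp = (inp * 2.0) + eps, where eps is almost (but not exactly) zero.
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4])
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 4])
mul_node = helper.make_node("Mul", ["inp", "scale"], ["mul_out"])
add_node = helper.make_node("Add", ["mul_out", "eps"], ["outp"])
graph = helper.make_graph([mul_node, add_node], "near_identity", [inp], [outp])
model = ModelWrapper(helper.make_model(graph))
model.set_initializer("scale", np.asarray([2.0], dtype=np.float32))
model.set_initializer("eps", np.asarray([1e-6], dtype=np.float32))
model = model.transform(InferShapes())

# With the default atol=1e-05 the near-zero Add is treated as an identity op and
# removed; a stricter tolerance, e.g. RemoveIdentityOps(atol=1e-07), would keep it.
model = model.transform(RemoveIdentityOps())
assert len(model.get_nodes_by_op_type("Add")) == 0
assert len(model.get_nodes_by_op_type("Mul")) == 1  # the non-identity Mul survives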