diff --git a/tests/transformation/streamline/test_remove_identity_ops.py b/tests/transformation/streamline/test_remove_identity_ops.py index 98430fad0e0f4c17d77ddbf44afeeccd44372047..d02e1d39755bf4783cd5dbdc2b88ca0931e02874 100644 --- a/tests/transformation/streamline/test_remove_identity_ops.py +++ b/tests/transformation/streamline/test_remove_identity_ops.py @@ -11,11 +11,17 @@ from finn.transformation.streamline.remove import RemoveIdentityOps from finn.util.basic import gen_finn_dt_tensor -def insert_identity_op(model, op, as_first_node): +def insert_identity_op(model, op, as_first_node, approx): + if approx: + zero_val = 0.000001 + one_val = 0.999999 + else: + zero_val = 0.0 + one_val = 1.0 if op in ["Add", "Sub"]: - val = np.asarray([0.0], dtype=np.float32) + val = np.asarray([zero_val], dtype=np.float32) elif op in ["Mul", "Div"]: - val = np.asarray([1.0], dtype=np.float32) + val = np.asarray([one_val], dtype=np.float32) else: return @@ -35,8 +41,9 @@ def insert_identity_op(model, op, as_first_node): # identity operations to be inserted @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) -def test_remove_identity_ops(op, as_first_node): +def test_remove_identity_ops(op, as_first_node, approx): # set up onnx model inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1]) @@ -70,7 +77,7 @@ def test_remove_identity_ops(op, as_first_node): model.set_initializer("shape", shape_values) model.set_initializer("div", div_values) model.set_initializer("matmul", matmul_values) - insert_identity_op(model, op, as_first_node) + insert_identity_op(model, op, as_first_node, approx) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) idict = {"inp": inp_values} @@ -84,4 +91,4 @@ def test_remove_identity_ops(op, as_first_node): odict = oxe.execute_onnx(model, idict) out_after = odict["outp"] - assert (out_before == out_after).all() + assert np.isclose(out_before, out_after, atol=1e-3).all()