diff --git a/tests/transformation/streamline/test_remove_identity_ops.py b/tests/transformation/streamline/test_remove_identity_ops.py
index 98430fad0e0f4c17d77ddbf44afeeccd44372047..d02e1d39755bf4783cd5dbdc2b88ca0931e02874 100644
--- a/tests/transformation/streamline/test_remove_identity_ops.py
+++ b/tests/transformation/streamline/test_remove_identity_ops.py
@@ -11,11 +11,17 @@ from finn.transformation.streamline.remove import RemoveIdentityOps
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def insert_identity_op(model, op, as_first_node):
+def insert_identity_op(model, op, as_first_node, approx):
+    if approx:
+        zero_val = 0.000001
+        one_val = 0.999999
+    else:
+        zero_val = 0.0
+        one_val = 1.0
     if op in ["Add", "Sub"]:
-        val = np.asarray([0.0], dtype=np.float32)
+        val = np.asarray([zero_val], dtype=np.float32)
     elif op in ["Mul", "Div"]:
-        val = np.asarray([1.0], dtype=np.float32)
+        val = np.asarray([one_val], dtype=np.float32)
     else:
         return
 
@@ -35,8 +41,9 @@ def insert_identity_op(model, op, as_first_node):
 
 # identity operations to be inserted
 @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
+@pytest.mark.parametrize("approx", [False, True])
 @pytest.mark.parametrize("as_first_node", [False, True])
-def test_remove_identity_ops(op, as_first_node):
+def test_remove_identity_ops(op, as_first_node, approx):
 
     # set up onnx model
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
@@ -70,7 +77,7 @@ def test_remove_identity_ops(op, as_first_node):
     model.set_initializer("shape", shape_values)
     model.set_initializer("div", div_values)
     model.set_initializer("matmul", matmul_values)
-    insert_identity_op(model, op, as_first_node)
+    insert_identity_op(model, op, as_first_node, approx)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     idict = {"inp": inp_values}
@@ -84,4 +91,4 @@ def test_remove_identity_ops(op, as_first_node):
 
     odict = oxe.execute_onnx(model, idict)
     out_after = odict["outp"]
-    assert (out_before == out_after).all()
+    assert np.isclose(out_before, out_after, atol=1e-3).all()