diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py index 0a36b8bbe5c05a8226ae647e0061c1551f3b1cbf..0abcf441f9636a52f9194df325874d293530f78a 100644 --- a/src/finn/transformation/streamline/remove.py +++ b/src/finn/transformation/streamline/remove.py @@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node): class RemoveIdentityOps(Transformation): - """Remove identity ops like Add/Sub with zero or Mul/Div with one""" + """Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance + value (defaults to 1e-05) can be specified during init for the comparison + to zero/one.""" + + def __init__(self, atol=1e-05): + super().__init__() + self.atol = atol def apply(self, model): graph = model.graph @@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation): and not model.is_join_node(n) ): A = model.get_initializer(n.input[1]) - if A is not None and (A == np.zeros_like(A)).all(): + if ( + A is not None + and np.isclose(A, np.zeros_like(A), atol=self.atol).all() + ): _remove_node_and_rewire(model, n) elif ( @@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation): and not model.is_join_node(n) ): A = model.get_initializer(n.input[1]) - if A is not None and (A == np.ones_like(A)).all(): + if ( + A is not None + and np.isclose(A, np.ones_like(A), atol=self.atol).all() + ): _remove_node_and_rewire(model, n) model = model.transform(InferShapes()) return (model, graph_modified)