From f70c3e1c3dbb80cfdde43aa7d9eddf687b1da49f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 1 Jun 2021 11:44:46 +0100 Subject: [PATCH] [Transform] accept approximation in RemoveIdentityOps --- src/finn/transformation/streamline/remove.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py index 0a36b8bbe..0abcf441f 100644 --- a/src/finn/transformation/streamline/remove.py +++ b/src/finn/transformation/streamline/remove.py @@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node): class RemoveIdentityOps(Transformation): - """Remove identity ops like Add/Sub with zero or Mul/Div with one""" + """Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance + value (defaults to 1e-05) can be specified during init for the comparison + to zero/one.""" + + def __init__(self, atol=1e-05): + super().__init__() + self.atol = atol def apply(self, model): graph = model.graph @@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation): and not model.is_join_node(n) ): A = model.get_initializer(n.input[1]) - if A is not None and (A == np.zeros_like(A)).all(): + if ( + A is not None + and np.isclose(A, np.zeros_like(A), atol=self.atol).all() + ): _remove_node_and_rewire(model, n) elif ( @@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation): and not model.is_join_node(n) ): A = model.get_initializer(n.input[1]) - if A is not None and (A == np.ones_like(A)).all(): + if ( + A is not None + and np.isclose(A, np.ones_like(A), atol=self.atol).all() + ): _remove_node_and_rewire(model, n) model = model.transform(InferShapes()) return (model, graph_modified) -- GitLab