From f70c3e1c3dbb80cfdde43aa7d9eddf687b1da49f Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 1 Jun 2021 11:44:46 +0100
Subject: [PATCH] [Transform] accept approximation in RemoveIdentityOps

---
 src/finn/transformation/streamline/remove.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/finn/transformation/streamline/remove.py b/src/finn/transformation/streamline/remove.py
index 0a36b8bbe..0abcf441f 100644
--- a/src/finn/transformation/streamline/remove.py
+++ b/src/finn/transformation/streamline/remove.py
@@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node):
 
 
 class RemoveIdentityOps(Transformation):
-    """Remove identity ops like Add/Sub with zero or Mul/Div with one"""
+    """Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance
+    value (defaults to 1e-05) can be specified during init for the comparison
+    to zero/one."""
+
+    def __init__(self, atol=1e-05):
+        super().__init__()
+        self.atol = atol
 
     def apply(self, model):
         graph = model.graph
@@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation):
                 and not model.is_join_node(n)
             ):
                 A = model.get_initializer(n.input[1])
-                if A is not None and (A == np.zeros_like(A)).all():
+                if (
+                    A is not None
+                    and np.isclose(A, np.zeros_like(A), atol=self.atol).all()
+                ):
                     _remove_node_and_rewire(model, n)
 
             elif (
@@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation):
                 and not model.is_join_node(n)
             ):
                 A = model.get_initializer(n.input[1])
-                if A is not None and (A == np.ones_like(A)).all():
+                if (
+                    A is not None
+                    and np.isclose(A, np.ones_like(A), atol=self.atol).all()
+                ):
                     _remove_node_and_rewire(model, n)
         model = model.transform(InferShapes())
         return (model, graph_modified)
-- 
GitLab