Skip to content
Snippets Groups Projects
Commit f70c3e1c authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] accept approximation in RemoveIdentityOps

parent aa331d12
No related branches found
No related tags found
No related merge requests found
...@@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node): ...@@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node):
class RemoveIdentityOps(Transformation): class RemoveIdentityOps(Transformation):
"""Remove identity ops like Add/Sub with zero or Mul/Div with one""" """Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance
value (defaults to 1e-05) can be specified during init for the comparison
to zero/one."""
def __init__(self, atol=1e-05):
super().__init__()
self.atol = atol
def apply(self, model): def apply(self, model):
graph = model.graph graph = model.graph
...@@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation): ...@@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation):
and not model.is_join_node(n) and not model.is_join_node(n)
): ):
A = model.get_initializer(n.input[1]) A = model.get_initializer(n.input[1])
if A is not None and (A == np.zeros_like(A)).all(): if (
A is not None
and np.isclose(A, np.zeros_like(A), atol=self.atol).all()
):
_remove_node_and_rewire(model, n) _remove_node_and_rewire(model, n)
elif ( elif (
...@@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation): ...@@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation):
and not model.is_join_node(n) and not model.is_join_node(n)
): ):
A = model.get_initializer(n.input[1]) A = model.get_initializer(n.input[1])
if A is not None and (A == np.ones_like(A)).all(): if (
A is not None
and np.isclose(A, np.ones_like(A), atol=self.atol).all()
):
_remove_node_and_rewire(model, n) _remove_node_and_rewire(model, n)
model = model.transform(InferShapes()) model = model.transform(InferShapes())
return (model, graph_modified) return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment