Skip to content
Snippets Groups Projects
Commit f70c3e1c authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] accept approximation in RemoveIdentityOps

parent aa331d12
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node):
class RemoveIdentityOps(Transformation):
"""Remove identity ops like Add/Sub with zero or Mul/Div with one"""
"""Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance
value (defaults to 1e-05) can be specified during init for the comparison
to zero/one."""
def __init__(self, atol=1e-05):
super().__init__()
self.atol = atol
def apply(self, model):
graph = model.graph
......@@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation):
and not model.is_join_node(n)
):
A = model.get_initializer(n.input[1])
if A is not None and (A == np.zeros_like(A)).all():
if (
A is not None
and np.isclose(A, np.zeros_like(A), atol=self.atol).all()
):
_remove_node_and_rewire(model, n)
elif (
......@@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation):
and not model.is_join_node(n)
):
A = model.get_initializer(n.input[1])
if A is not None and (A == np.ones_like(A)).all():
if (
A is not None
and np.isclose(A, np.ones_like(A), atol=self.atol).all()
):
_remove_node_and_rewire(model, n)
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment