Skip to content
Snippets Groups Projects
Unverified Commit 3a50a3f4 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #340 from Xilinx/feature/sf_rounding

Remove close-to-identity ops
parents 5b832859 79e77052
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,13 @@ def _remove_node_and_rewire(model, node):
class RemoveIdentityOps(Transformation):
"""Remove identity ops like Add/Sub with zero or Mul/Div with one"""
"""Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance
value (defaults to 1e-05) can be specified during init for the comparison
to zero/one."""
def __init__(self, atol=1e-05):
super().__init__()
self.atol = atol
def apply(self, model):
graph = model.graph
......@@ -64,7 +70,10 @@ class RemoveIdentityOps(Transformation):
and not model.is_join_node(n)
):
A = model.get_initializer(n.input[1])
if A is not None and (A == np.zeros_like(A)).all():
if (
A is not None
and np.isclose(A, np.zeros_like(A), atol=self.atol).all()
):
_remove_node_and_rewire(model, n)
elif (
......@@ -73,7 +82,10 @@ class RemoveIdentityOps(Transformation):
and not model.is_join_node(n)
):
A = model.get_initializer(n.input[1])
if A is not None and (A == np.ones_like(A)).all():
if (
A is not None
and np.isclose(A, np.ones_like(A), atol=self.atol).all()
):
_remove_node_and_rewire(model, n)
model = model.transform(InferShapes())
return (model, graph_modified)
......@@ -188,6 +188,10 @@ def test_end2end_mobilenet_streamline():
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
model.save(build_dir + "/end2end_mobilenet_streamlined.onnx")
assert (
len(model.get_nodes_by_op_type("Add")) == 1
) # only final quantized bias Add op remains
assert len(model.get_nodes_by_op_type("Mul")) == 0 # no Mul ops remain
def test_end2end_mobilenet_lowering():
......
......@@ -11,11 +11,17 @@ from finn.transformation.streamline.remove import RemoveIdentityOps
from finn.util.basic import gen_finn_dt_tensor
def insert_identity_op(model, op, as_first_node):
def insert_identity_op(model, op, as_first_node, approx):
if approx:
zero_val = 0.000001
one_val = 0.999999
else:
zero_val = 0.0
one_val = 1.0
if op in ["Add", "Sub"]:
val = np.asarray([0.0], dtype=np.float32)
val = np.asarray([zero_val], dtype=np.float32)
elif op in ["Mul", "Div"]:
val = np.asarray([1.0], dtype=np.float32)
val = np.asarray([one_val], dtype=np.float32)
else:
return
......@@ -35,8 +41,9 @@ def insert_identity_op(model, op, as_first_node):
# identity operations to be inserted
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
@pytest.mark.parametrize("approx", [False, True])
@pytest.mark.parametrize("as_first_node", [False, True])
def test_remove_identity_ops(op, as_first_node):
def test_remove_identity_ops(op, as_first_node, approx):
# set up onnx model
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
......@@ -70,7 +77,7 @@ def test_remove_identity_ops(op, as_first_node):
model.set_initializer("shape", shape_values)
model.set_initializer("div", div_values)
model.set_initializer("matmul", matmul_values)
insert_identity_op(model, op, as_first_node)
insert_identity_op(model, op, as_first_node, approx)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
idict = {"inp": inp_values}
......@@ -84,4 +91,4 @@ def test_remove_identity_ops(op, as_first_node):
odict = oxe.execute_onnx(model, idict)
out_after = odict["outp"]
assert (out_before == out_after).all()
assert np.isclose(out_before, out_after, atol=1e-3).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment