Skip to content
Snippets Groups Projects
Commit 64779fdc authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[TEST] Add test to verify that CollapseRepeatedOp only operates on linear...

[TEST] Add test to verify that CollapseRepeatedOp only operates on linear segments. Add check to previous test to verify that the transformation is doing the required job
parent bfd42761
No related merge requests found
......@@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
import pytest
def test_collapse_repeated_op():
......@@ -67,3 +68,60 @@ def test_collapse_repeated_op():
new_model = new_model.transform(CollapseRepeatedMul())
inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert len(new_model.graph.node) == 2
assert new_model.graph.node[0].op_type == "Add"
assert new_model.graph.node[1].op_type == "Mul"
@pytest.mark.parametrize(
    "test_args", [("Add", CollapseRepeatedAdd()), ("Mul", CollapseRepeatedMul())],
)
def test_collapse_repeated_only_if_linear(test_args):
    """Verify CollapseRepeated* collapses only linear (non-forking) segments.

    Graph topology: op(p2) -> op(p1) -> fork at t2 into op(p3) and op(p4)
    -> joined by op(t3, t4) -> op(p5). Only the first two ops form a linear
    chain, so exactly one collapse may happen: 6 nodes -> 5 nodes.

    Parameters
    ----------
    test_args : tuple
        (scalar op_type name, instantiated transformation to apply).
    """
    scalar_op, transf_fxn = test_args
    input_shape = [4, 4]
    output_shape = input_shape
    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
    # scalar parameters p1..p5, each declared with shape [1]
    param_names = ["p1", "p2", "p3", "p4", "p5"]
    value_info = [
        oh.make_tensor_value_info(name, TensorProto.FLOAT, [1])
        for name in param_names
    ]
    modelproto = oh.make_model(
        oh.make_graph(
            name="test",
            inputs=[top_in],
            outputs=[top_out],
            value_info=value_info,
            nodes=[
                oh.make_node(scalar_op, ["top_in", "p2"], ["t1"]),
                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
                # t2 fans out to two consumers -> not a linear segment,
                # so nothing downstream of here may be collapsed
                oh.make_node(scalar_op, ["t2", "p3"], ["t3"]),
                oh.make_node(scalar_op, ["t2", "p4"], ["t4"]),
                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
                oh.make_node(scalar_op, ["t5", "p5"], ["top_out"]),
            ],
        )
    )
    model = ModelWrapper(modelproto)
    model = model.transform(InferShapes())
    np.random.seed(0)
    for name in param_names:
        # shape-[1] array matches the declared value_info shape above
        # (previously the array was star-unpacked into a bare scalar)
        model.set_initializer(name, np.random.rand(1).astype(np.float32))
    # Transform
    new_model = model.transform(transf_fxn)
    # Test: execution results unchanged and exactly one collapse occurred
    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
    assert ox.compare_execution(model, new_model, inp_dict)
    assert len(new_model.graph.node) == 5
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment