Skip to content
Snippets Groups Projects
Commit 18067458 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[TEST] Add test linear_past_eltwise_add when the graph has multiple joins...

[TEST] Add test linear_past_eltwise_add when the graph has multiple joins where the transformation applies (it was failing in this case), plus test moving linear tensor ops, not only scalar ones
parent 188b6ac5
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.streamline.reorder import MoveScalarLinearPastEltwiseAdd
from finn.transformation.streamline.reorder import MoveLinearPastEltwiseAdd
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.double_to_single_float import DoubleToSingleFloat
......@@ -95,7 +95,7 @@ def make_model(shape):
@pytest.mark.parametrize("ch", [64])
# ifmdim
@pytest.mark.parametrize("ifmdim", [-1, 7])
def test_scalar_past_eltwise(ch, ifmdim):
def test_linear_past_eltwise_add(ch, ifmdim):
# generate test vectors of correct shape
if ifmdim == -1:
input_tensor_shape = (1, ch)
......@@ -124,7 +124,7 @@ def test_scalar_past_eltwise(ch, ifmdim):
assert len(model.get_nodes_by_op_type("Add")) == 3
assert len(model.get_nodes_by_op_type("Mul")) == 2
model = model.transform(MoveScalarLinearPastEltwiseAdd())
model = model.transform(MoveLinearPastEltwiseAdd())
# verify again, to check we didnt break anything
output_dict = oxe.execute_onnx(model, input_dict, True)
......@@ -134,3 +134,68 @@ def test_scalar_past_eltwise(ch, ifmdim):
assert len(model.get_nodes_by_op_type("Mul")) == 1
os.remove(export_onnx_path)
@pytest.mark.parametrize("ch", [64, 1])
# ifmdim
@pytest.mark.parametrize("ifmdim", [-1, 7])
def test_linear_past_eltwise_add_multiple_forks(ch, ifmdim):
    """Exercise MoveLinearPastEltwiseAdd on a graph containing two fork
    tensors (each consumed by two nodes that later join in an eltwise Add).

    Verifies both that execution results are unchanged by the transform and
    that the resulting node ordering has the linear ops moved past the Adds.
    """
    # 2D input when ifmdim == -1, otherwise a 4D NCHW-style input
    shape = (1, ch) if ifmdim == -1 else (1, ch, ifmdim, ifmdim)
    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, shape)
    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, shape)

    # six elementwise parameter tensors p0..p5, all input-shaped
    param_names = ["p" + str(i) for i in range(6)]
    value_info = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        for name in param_names
    ]

    graph = helper.make_graph(
        name="test",
        inputs=[top_in],
        outputs=[top_out],
        value_info=value_info,
        nodes=[
            # fork1 feeds two Mul branches that rejoin in an Add
            helper.make_node("Add", ["top_in", "p0"], ["fork1"]),
            helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
            helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
            helper.make_node("Add", ["t2", "t3"], ["t4"]),
            # fork2 feeds two Add branches that rejoin in an Add
            helper.make_node("Mul", ["t4", "p3"], ["fork2"]),
            helper.make_node("Add", ["fork2", "p4"], ["t5"]),
            helper.make_node("Add", ["fork2", "p5"], ["t6"]),
            helper.make_node("Add", ["t5", "t6"], ["top_out"]),
        ],
    )
    model = ModelWrapper(helper.make_model(graph))
    model = model.transform(InferShapes())

    # deterministic random initializers for every parameter tensor
    np.random.seed(0)
    for name in param_names:
        model.set_initializer(name, np.random.rand(*shape).astype(np.float32))
    # the two Mul branches must share the same factor for the transform to apply
    model.set_initializer("p2", model.get_initializer("p1"))

    # apply the transformation under test
    new_model = model.transform(MoveLinearPastEltwiseAdd())
    inp_dict = {"top_in": np.random.rand(*shape).astype(np.float32)}

    # execution behavior must be preserved
    assert oxe.compare_execution(model, new_model, inp_dict)
    # linear ops are now moved past the eltwise Adds at both joins
    assert new_model.graph.node[0].op_type == "Add"
    assert new_model.graph.node[1].op_type == "Add"
    assert new_model.graph.node[2].op_type == "Mul"
    assert new_model.graph.node[3].op_type == "Mul"
    assert new_model.graph.node[4].op_type == "Add"
    assert new_model.graph.node[5].op_type == "Add"
    assert len(new_model.graph.node) == 6
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment