Skip to content
Snippets Groups Projects
Commit 4b79ea38 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] use ONNX textual input for test_move_past_fork testcase

parent 4d0c0c43
No related branches found
No related tags found
No related merge requests found
......@@ -28,9 +28,11 @@
import pytest
import numpy as np
from onnx import TensorProto, helper
import onnx.parser as oprs
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name
import finn.core.onnx_exec as oxe
from finn.transformation.streamline.reorder import MoveLinearPastFork
......@@ -41,67 +43,65 @@ from finn.transformation.streamline.reorder import MoveLinearPastFork
# ifmdim
@pytest.mark.parametrize("ifmdim", [-1, 7])
def test_move_past_fork(ch, ifmdim):
# generate test vectors of correct shape
if ifmdim == -1:
input_shape = (1, ch)
shp = [1, ch]
else:
input_shape = (1, ch, ifmdim, ifmdim)
shp = [1, ch, ifmdim, ifmdim]
shp_str = str(shp)
input = f"""
<
ir_version: 7,
opset_import: ["" : 9]
>
agraph (float{shp_str} in0) => (float{shp_str} out0)
<
float{shp_str} add0_param,
float{shp_str} mul_shared_param,
float{shp_str} add2_param,
float{shp_str} mul2_param,
float{shp_str} add3_param,
float{shp_str} add4_param,
float{shp_str} mul3_param,
float{shp_str} add6_param
>
{{
top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
num_of_params = 8
value_info = []
for i in range(num_of_params):
value_info += [
helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
]
add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
modelproto = helper.make_model(
helper.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=[
# fork1
add_1_to_move,
helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
helper.make_node("Add", ["t2", "t3"], ["t4"]),
helper.make_node("Add", ["t4", "p3"], ["t5"]),
# fork2
mul_1_to_move,
add_2_to_move,
helper.make_node("Add", ["fork2", "p6"], ["t7"]),
helper.make_node("Add", ["t6", "t7"], ["t8"]),
# empty branches: do nothing
mul_1_not_to_move,
helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
add0_out = Add(in0, add0_param)
mul0_out = Mul(add0_out, mul_shared_param)
mul1_out = Mul(add0_out, mul_shared_param)
add1_out = Add(mul0_out, mul1_out)
add2_out = Add(add1_out, add2_param)
mul2_out = Mul(add2_out, mul2_param)
add3_out = Add(mul2_out, add3_param)
add4_out = Add(mul2_out, add4_param)
add5_out = Add(add3_out, add4_out)
mul3_out = Mul(add5_out, mul3_param)
out0 = Add(mul3_out, add6_param)
}}
"""
model = oprs.parse_model(input)
model = ModelWrapper(model)
model = model.transform(InferShapes())
np.random.seed(0)
for i in range(num_of_params):
model.set_initializer(
"p" + str(i), np.random.rand(*input_shape).astype(np.float32)
)
for tensor_name in model.get_all_tensor_names():
if tensor_name.endswith("_param"):
pshape = model.get_tensor_shape(tensor_name)
model.set_initializer(
tensor_name, np.random.rand(*pshape).astype(np.float32)
)
model = model.transform(GiveUniqueNodeNames())
# Transform
new_model = model.transform(MoveLinearPastFork())
inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
new_model = new_model.transform(GiveUniqueNodeNames())
inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)}
# Test
assert oxe.compare_execution(model, new_model, inp_dict)
assert not new_model.is_fork_node(add_1_to_move)
assert not new_model.is_fork_node(mul_1_to_move)
assert not new_model.is_fork_node(add_2_to_move)
assert new_model.is_fork_node(mul_1_not_to_move)
nodes = new_model.graph.node
assert len(new_model.get_nodes_by_op_type("Add")) == 9
assert len(new_model.get_nodes_by_op_type("Mul")) == 5
assert not new_model.is_fork_node(get_by_name(nodes, "Add_0"))
assert new_model.is_join_node(get_by_name(nodes, "Add_2"))
assert not new_model.is_fork_node(get_by_name(nodes, "Mul_2"))
assert not new_model.is_join_node(get_by_name(nodes, "Add_5"))
assert len(new_model.graph.node) == 14
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment