Skip to content
Snippets Groups Projects
Commit 4b79ea38 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] use ONNX textual input for test_move_past_fork testcase

parent 4d0c0c43
No related branches found
No related tags found
No related merge requests found
...@@ -28,9 +28,11 @@ ...@@ -28,9 +28,11 @@
import pytest import pytest
import numpy as np import numpy as np
from onnx import TensorProto, helper import onnx.parser as oprs
from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name
import finn.core.onnx_exec as oxe import finn.core.onnx_exec as oxe
from finn.transformation.streamline.reorder import MoveLinearPastFork from finn.transformation.streamline.reorder import MoveLinearPastFork
...@@ -41,67 +43,65 @@ from finn.transformation.streamline.reorder import MoveLinearPastFork ...@@ -41,67 +43,65 @@ from finn.transformation.streamline.reorder import MoveLinearPastFork
# ifmdim # ifmdim
@pytest.mark.parametrize("ifmdim", [-1, 7])
def test_move_past_fork(ch, ifmdim):
    """Check that MoveLinearPastFork moves linear ops past fork nodes.

    Builds a small Add/Mul graph from ONNX textual syntax containing three
    fork nodes (add0_out, mul2_out and the final mul3_out producer's inputs),
    runs MoveLinearPastFork, and verifies that (a) execution results are
    unchanged and (b) the expected nodes are no longer forks / have become
    joins.

    Parameters
    ----------
    ch : int
        Number of channels (second dimension of the tensor shape).
        NOTE(review): parametrized by a decorator outside this view.
    ifmdim : int
        Spatial dimension; -1 selects a 2D (1, ch) shape, otherwise the
        shape is (1, ch, ifmdim, ifmdim).
    """
    if ifmdim == -1:
        shp = [1, ch]
    else:
        shp = [1, ch, ifmdim, ifmdim]
    shp_str = str(shp)
    # ONNX textual-syntax graph; *_param tensors are declared as value_info
    # and receive random initializers below.
    graph_def = f"""
    <
        ir_version: 7,
        opset_import: ["" : 9]
    >
    agraph (float{shp_str} in0) => (float{shp_str} out0)
    <
        float{shp_str} add0_param,
        float{shp_str} mul_shared_param,
        float{shp_str} add2_param,
        float{shp_str} mul2_param,
        float{shp_str} add3_param,
        float{shp_str} add4_param,
        float{shp_str} mul3_param,
        float{shp_str} add6_param
    >
    {{
        add0_out = Add(in0, add0_param)
        mul0_out = Mul(add0_out, mul_shared_param)
        mul1_out = Mul(add0_out, mul_shared_param)
        add1_out = Add(mul0_out, mul1_out)
        add2_out = Add(add1_out, add2_param)
        mul2_out = Mul(add2_out, mul2_param)
        add3_out = Add(mul2_out, add3_param)
        add4_out = Add(mul2_out, add4_param)
        add5_out = Add(add3_out, add4_out)
        mul3_out = Mul(add5_out, mul3_param)
        out0 = Add(mul3_out, add6_param)
    }}
    """
    model = ModelWrapper(oprs.parse_model(graph_def))
    model = model.transform(InferShapes())
    np.random.seed(0)
    # Give every declared *_param tensor a random initializer of its
    # inferred shape.
    for tensor_name in model.get_all_tensor_names():
        if tensor_name.endswith("_param"):
            pshape = model.get_tensor_shape(tensor_name)
            model.set_initializer(
                tensor_name, np.random.rand(*pshape).astype(np.float32)
            )
    model = model.transform(GiveUniqueNodeNames())
    # Transform
    new_model = model.transform(MoveLinearPastFork())
    new_model = new_model.transform(GiveUniqueNodeNames())
    # BUGFIX: the graph input is named "in0" in the textual definition above;
    # the old hand-built graph used "top_in", which no longer exists.
    inp_dict = {"in0": np.random.rand(*shp).astype(np.float32)}
    # Test: identical execution results before and after the transform
    assert oxe.compare_execution(model, new_model, inp_dict)
    nodes = new_model.graph.node
    assert len(new_model.get_nodes_by_op_type("Add")) == 9
    assert len(new_model.get_nodes_by_op_type("Mul")) == 5
    # Moved linear ops must no longer be forks; duplicated producers join.
    assert not new_model.is_fork_node(get_by_name(nodes, "Add_0"))
    assert new_model.is_join_node(get_by_name(nodes, "Add_2"))
    assert not new_model.is_fork_node(get_by_name(nodes, "Mul_2"))
    assert not new_model.is_join_node(get_by_name(nodes, "Add_5"))
    assert len(new_model.graph.node) == 14
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment