Commit 8f7a26b6 authored by Yaman Umuroglu

[Refactor] use textual ONNX for test_absorb_opposite_transposes

parent cb5f0f36
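
Note on the change: onnx.parser (shipped with the onnx package since roughly version 1.10, as far as I know) turns a textual graph description directly into a ModelProto, which is what replaces the verbose onnx.helper / TensorProto construction in the diff below. A minimal sketch of the textual format, using illustrative names (tiny_graph, in0, out0, bias) that are not taken from the commit:

import onnx.parser as oprs

# Textual ONNX: model header in <...>, graph signature, optional initializer
# block in <...>, then the node list in {...}.
txt = """
<
    ir_version: 7,
    opset_import: ["" : 9]
>
tiny_graph (float[1, 3] in0) => (float[1, 3] out0)
<
    float[1] bias = {1.0}
>
{
    out0 = Add(in0, bias)
}
"""
model_proto = oprs.parse_model(txt)
print(len(model_proto.graph.node))  # 1 -- a single Add node, no onnx.helper calls needed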
@@ -29,8 +29,7 @@
 import pytest
 import numpy as np
-import onnx.helper as oh
-from onnx import TensorProto
+import onnx.parser as oprs
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes
@@ -41,38 +40,36 @@ from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes
 @pytest.mark.streamline
 def test_absorb_opposite_transposes():
     np.random.seed(0)
-    input_shape = [1, 3, 4, 2]
-    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
-    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
-    value_info = [oh.make_tensor_value_info("add_param_0", TensorProto.FLOAT, [1])]
-    value_info += [oh.make_tensor_value_info("add_param_1", TensorProto.FLOAT, [1])]
-    value_info += [oh.make_tensor_value_info("mul_param_0", TensorProto.FLOAT, [1])]
-    modelproto = oh.make_model(
-        oh.make_graph(
-            name="test",
-            inputs=[top_in],
-            outputs=[top_out],
-            value_info=value_info,
-            nodes=[
-                oh.make_node("Add", ["top_in", "add_param_0"], ["t0"]),
-                oh.make_node("Transpose", ["t0"], ["t1"], perm=[0, 2, 3, 1]),
-                oh.make_node("Transpose", ["t1"], ["t2"], perm=[0, 3, 1, 2]),
-                oh.make_node("Add", ["t2", "add_param_1"], ["t3"]),
-                oh.make_node("Transpose", ["t3"], ["t4"], perm=[0, 2, 3, 1]),
-                oh.make_node("Transpose", ["t4"], ["t5"], perm=[0, 3, 1, 2]),
-                oh.make_node("Add", ["t5", "t2"], ["t6"]),
-                oh.make_node("Mul", ["t6", "mul_param_0"], ["top_out"]),
-            ],
-        )
-    )
-    model = ModelWrapper(modelproto)
+    shp = [1, 3, 4, 2]
+    shp_str = str(shp)
+    input = f"""
+    <
+        ir_version: 7,
+        opset_import: ["" : 9]
+    >
+    agraph (float{shp_str} in0) => (float{shp_str} out0)
+    <
+        float[1] add0_param = {{1.0}},
+        float[1] add1_param = {{3.0}},
+        float[1] mul0_param = {{2.0}}
+    >
+    {{
+        add0_out = Add(in0, add0_param)
+        t0_out = Transpose<perm=[0,2,3,1]>(add0_out)
+        t1_out = Transpose<perm=[0,3,1,2]>(t0_out)
+        add1_out = Add(t1_out, add1_param)
+        t2_out = Transpose<perm=[0,2,3,1]>(add1_out)
+        t3_out = Transpose<perm=[0,3,1,2]>(t2_out)
+        add2_out = Add(t1_out, t3_out)
+        out0 = Mul(add2_out, mul0_param)
+    }}
+    """
+    model = oprs.parse_model(input)
+    model = ModelWrapper(model)
     model = model.transform(InferShapes())
-    model.set_initializer("add_param_0", np.asarray([1], dtype=np.float32))
-    model.set_initializer("add_param_1", np.asarray([3], dtype=np.float32))
-    model.set_initializer("mul_param_0", np.asarray([2], dtype=np.float32))
     new_model = model.transform(AbsorbConsecutiveTransposes())
     new_model = new_model.transform(InferShapes())
-    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    inp_dict = {"top_in": np.random.rand(*shp).astype(np.float32)}
     assert ox.compare_execution(model, model, inp_dict)
     assert len(new_model.graph.node) == 4
     for n in new_model.graph.node:
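
For reference, the expected node count in the assertion follows from the two permutations used in the test being mutual inverses: every Transpose<perm=[0,2,3,1]> followed by Transpose<perm=[0,3,1,2]> is a no-op, so AbsorbConsecutiveTransposes can drop all four Transpose nodes and leave the three Adds and one Mul, i.e. 4 nodes. A quick standalone check of that permutation identity (illustrative, not part of the commit):

import numpy as np

# Applying perm [0, 2, 3, 1] and then perm [0, 3, 1, 2] restores the original
# axis order, which is why the back-to-back Transpose pairs cancel.
x = np.random.rand(1, 3, 4, 2)
roundtrip = x.transpose(0, 2, 3, 1).transpose(0, 3, 1, 2)
assert np.array_equal(x, roundtrip)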