From d9cf7042eccb81d3e1f039fe703713a76f1ca76b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Wed, 30 Oct 2019 11:11:32 +0000 Subject: [PATCH] [Test] add test_move_scalar_mul_past_matmul --- tests/test_move_scalar_mul_past_matmul.py | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/test_move_scalar_mul_past_matmul.py diff --git a/tests/test_move_scalar_mul_past_matmul.py b/tests/test_move_scalar_mul_past_matmul.py new file mode 100644 index 000000000..55a7d8cbe --- /dev/null +++ b/tests/test_move_scalar_mul_past_matmul.py @@ -0,0 +1,41 @@ +import numpy as np +import onnx.helper as oh +from onnx import TensorProto + +import finn.core.onnx_exec as ox +import finn.transformation.infer_shapes as si +import finn.transformation.streamline as tx +from finn.core.modelwrapper import ModelWrapper + + +def test_move_scalar_mul_past_matmul(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [1, 1]) + matmul_param = oh.make_tensor_value_info("matmul_param", TensorProto.FLOAT, [2, 2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[mul_param, matmul_param], + nodes=[ + oh.make_node("Mul", ["top_in", "mul_param"], ["middle"]), + oh.make_node("MatMul", ["middle", "matmul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + model.set_initializer("mul_param", np.asarray([[3]], dtype=np.float32)) + model.set_initializer( + "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32) + ) + new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul) + inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} + out_orig = ox.execute_onnx(model, inp_dict)["top_out"] + out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] + assert np.isclose(out_orig, out_transformed).all() + assert new_model.graph.node[0].op_type == "MatMul" + assert new_model.graph.node[1].op_type == "Mul" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] -- GitLab