From e8fb53c44c6f295a9c65b86ed868744078b42e63 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 24 Oct 2019 12:07:14 +0100 Subject: [PATCH] [Test] add two test_is_linear variants --- tests/test_is_linear.py | 57 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/test_is_linear.py diff --git a/tests/test_is_linear.py b/tests/test_is_linear.py new file mode 100644 index 000000000..1995604b7 --- /dev/null +++ b/tests/test_is_linear.py @@ -0,0 +1,57 @@ +import onnx.helper as oh +from onnx import TensorProto + +import finn.analysis.topology as ta +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper + + +def test_is_linear_linear(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul_param], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("Mul", ["middle", "mul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.is_linear) + assert ret["is_linear"] is True + + +def test_is_linear_forked_node_output(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [2]) + mul0_param = oh.make_tensor_value_info("mul0_param", TensorProto.FLOAT, [2]) + mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [2]) + mul0_res = oh.make_tensor_value_info("mul0_res", TensorProto.FLOAT, [2]) + mul1_res = oh.make_tensor_value_info("mul1_res", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul0_param, mul1_param, mul0_res, mul1_res], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("Mul", ["middle", "mul0_param"], ["mul0_res"]), + oh.make_node("Mul", ["middle", "mul1_param"], ["mul1_res"]), + oh.make_node("Add", ["mul0_res", "mul1_res"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.is_linear) + assert ret["is_linear"] is False -- GitLab