diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py index aa059747b602bc6b659bc8b53b1f18988bba1ef0..67824ad4f633983b93e3178d03118927a1ddd85b 100644 --- a/src/finn/transformation/streamline/collapse_repeated.py +++ b/src/finn/transformation/streamline/collapse_repeated.py @@ -48,9 +48,17 @@ class CollapseRepeatedOp(Transformation): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == self.op_name: + if ( + n.op_type == self.op_name + and not model.is_fork_node(n) + and not model.is_join_node(n) + ): consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == self.op_name: + if ( + consumer is not None + and consumer.op_type == self.op_name + and not model.is_join_node(consumer) + ): op0_param_name = n.input[1] op1_param_name = consumer.input[1] op0_param = model.get_initializer(op0_param_name) diff --git a/tests/transformation/test_collapse_repeated_op.py b/tests/transformation/test_collapse_repeated_op.py index 01d932ece0be4b0beb7ad6094284ec3efb1e525e..b74d868f9b921c35ff9f596c811583f45f761374 100644 --- a/tests/transformation/test_collapse_repeated_op.py +++ b/tests/transformation/test_collapse_repeated_op.py @@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_shapes import InferShapes from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul +import pytest def test_collapse_repeated_op(): @@ -67,3 +68,60 @@ def test_collapse_repeated_op(): new_model = new_model.transform(CollapseRepeatedMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) + assert len(new_model.graph.node) == 2 + assert new_model.graph.node[0].op_type == "Add" + assert new_model.graph.node[1].op_type == "Mul" + + +@pytest.mark.parametrize( + "test_args", [("Add", CollapseRepeatedAdd()), ("Mul", CollapseRepeatedMul())], +) +def test_collapse_repeated_only_if_linear(test_args): + scalar_op = test_args[0] + transf_fxn = test_args[1] + + input_shape = [4, 4] + output_shape = input_shape + + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape) + + value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, [1])] + value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, [1])] + + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + oh.make_node(scalar_op, ["top_in", "p2"], ["t1"]), + oh.make_node(scalar_op, ["t1", "p1"], ["t2"]), + oh.make_node(scalar_op, ["t2", "p3"], ["t3"]), + oh.make_node(scalar_op, ["t2", "p4"], ["t4"]), + oh.make_node(scalar_op, ["t3", "t4"], ["t5"]), + oh.make_node(scalar_op, ["t5", "p5"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + + np.random.seed(0) + model.set_initializer("p1", *np.random.rand(1).astype(np.float32)) + model.set_initializer("p2", *np.random.rand(1).astype(np.float32)) + model.set_initializer("p3", *np.random.rand(1).astype(np.float32)) + model.set_initializer("p4", *np.random.rand(1).astype(np.float32)) + model.set_initializer("p5", *np.random.rand(1).astype(np.float32)) + + # Transform + new_model = model.transform(transf_fxn) + + # Test + inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)} + assert ox.compare_execution(model, new_model, inp_dict) + assert len(new_model.graph.node) == 5