diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 96046602efb32a9262a4cf0bbb21a8367d719910..1886c785705161c3a13493de44dc3f3f86463f4f 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes from finn.core.onnx_exec import execute_node from finn.util.basic import get_by_name -def is_scalar(x): -    return np.prod(x.shape) == 1 class MoveAddPastMul(Transformation): """Move add operations past multiply operations. The aim is to have them @@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation): return (model, graph_modified) -class MoveScalarLinearPastEltwiseAdd(Transformation): -    """Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically, -    matches and transforms the following patterns: +class MoveLinearPastEltwiseAdd(Transformation): +    """Move linear operations (mul, add) past elementwise add operations where possible. +    Specifically, matches and transforms the following patterns: (x*C) + (y*C) -> (x + y) * C (x+A) + (y+B) -> (x + y) + (A + B) -    where x and y are dynamic inputs, A, B, C are constants. +    where x and y are dynamic inputs, A, B, C are constant tensors (in general). 
""" def move_node(self, graph, n, prod0, prod1, node_ind): @@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): graph = model.graph node_ind = 0 graph_modified = False - for n in graph.node: + nodes = [n for n in graph.node] + for n in nodes: node_ind += 1 if n.op_type == "Add": # check for tensors on both inputs (eltwise add) @@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation): # check for mul with same initializer on both inputs prod0 = model.find_producer(in0) prod1 = model.find_producer(in1) - if prod0 is None or prod1 is None: + # Also check case when both branches are empty and come + # from the same node: (prod0 == prod1) + # Other transform should handle that + if prod0 is None or prod1 is None or (prod0 == prod1): continue init0 = model.get_initializer(prod0.input[1]) init1 = model.get_initializer(prod1.input[1]) # if either initializer is None, skip if init0 is None or init1 is None: continue - # if either initializer is non-scalar, skip - # TODO relax this to 1D tensors? 
- if (not is_scalar(init0)) or (not is_scalar(init1)): - continue if prod0.op_type == "Mul" and prod1.op_type == "Mul": if np.array_equal(init0, init1): self.move_node(graph, n, prod0, prod1, node_ind) diff --git a/tests/transformation/test_scalar_past_eltwise.py b/tests/transformation/test_linear_past_eltwise.py similarity index 69% rename from tests/transformation/test_scalar_past_eltwise.py rename to tests/transformation/test_linear_past_eltwise.py index e845f32176a9293046b297b7d9e2ab64fabc1791..b77f59779a5e8559f80e017d13b66bcb67249830 100644 --- a/tests/transformation/test_scalar_past_eltwise.py +++ b/tests/transformation/test_linear_past_eltwise.py @@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper from finn.transformation.fold_constants import FoldConstants from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames -from finn.transformation.streamline.reorder import MoveScalarLinearPastEltwiseAdd +from finn.transformation.streamline.reorder import MoveLinearPastEltwiseAdd from finn.transformation.infer_shapes import InferShapes from finn.transformation.double_to_single_float import DoubleToSingleFloat @@ -95,7 +95,7 @@ def make_model(shape): @pytest.mark.parametrize("ch", [64]) # ifmdim @pytest.mark.parametrize("ifmdim", [-1, 7]) -def test_scalar_past_eltwise(ch, ifmdim): +def test_linear_past_eltwise_add(ch, ifmdim): # generate test vectors of correct shape if ifmdim == -1: input_tensor_shape = (1, ch) @@ -124,7 +124,7 @@ def test_scalar_past_eltwise(ch, ifmdim): assert len(model.get_nodes_by_op_type("Add")) == 3 assert len(model.get_nodes_by_op_type("Mul")) == 2 - model = model.transform(MoveScalarLinearPastEltwiseAdd()) + model = model.transform(MoveLinearPastEltwiseAdd()) # verify again, to check we didnt break anything output_dict = oxe.execute_onnx(model, input_dict, True) @@ -134,3 +134,68 @@ def test_scalar_past_eltwise(ch, ifmdim): assert len(model.get_nodes_by_op_type("Mul")) 
== 1 os.remove(export_onnx_path) + + +@pytest.mark.parametrize("ch", [64, 1]) +# ifmdim +@pytest.mark.parametrize("ifmdim", [-1, 7]) +def test_linear_past_eltwise_add_multiple_forks(ch, ifmdim): + # generate test vectors of correct shape + if ifmdim == -1: + input_shape = (1, ch) + else: + input_shape = (1, ch, ifmdim, ifmdim) + + top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape) + top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape) + + num_of_params = 6 + value_info = [] + for i in range(num_of_params): + value_info += [ + helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape) + ] + + modelproto = helper.make_model( + helper.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=value_info, + nodes=[ + helper.make_node("Add", ["top_in", "p0"], ["fork1"]), + helper.make_node("Mul", ["fork1", "p1"], ["t2"]), + helper.make_node("Mul", ["fork1", "p2"], ["t3"]), + helper.make_node("Add", ["t2", "t3"], ["t4"]), + helper.make_node("Mul", ["t4", "p3"], ["fork2"]), + helper.make_node("Add", ["fork2", "p4"], ["t5"]), + helper.make_node("Add", ["fork2", "p5"], ["t6"]), + helper.make_node("Add", ["t5", "t6"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + + np.random.seed(0) + for i in range(num_of_params): + model.set_initializer( + "p" + str(i), np.random.rand(*input_shape).astype(np.float32) + ) + + # need equal mults: + model.set_initializer("p2", model.get_initializer("p1")) + + # Transform + new_model = model.transform(MoveLinearPastEltwiseAdd()) + inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)} + + # Test + assert oxe.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "Add" + assert new_model.graph.node[1].op_type == "Add" + assert new_model.graph.node[2].op_type == "Mul" + assert new_model.graph.node[3].op_type == "Mul" + assert 
new_model.graph.node[4].op_type == "Add" + assert new_model.graph.node[5].op_type == "Add" + assert len(new_model.graph.node) == 6