diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index cdf99dc3bd8b698bec60d79ef6e34640ac3b740c..ed32426abcc8ea71428a7f746a99454e8e4a2c17 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -333,6 +333,22 @@ class ModelWrapper:
         else:
             return None
 
+    def is_fork_node(self, node):
+        """Checks if the given node is a fork, that is, the node has multiple
+        direct successors"""
+        direct_successors = self.find_direct_successors(node)
+        is_fork = False if direct_successors is None else (len(direct_successors) > 1)
+        return is_fork
+
+    def is_join_node(self, node):
+        """Checks if the given node is a join, that is, the node has multiple
+        direct predecessors"""
+        direct_predecessors = self.find_direct_predecessors(node)
+        is_join = (
+            False if direct_predecessors is None else (len(direct_predecessors) > 1)
+        )
+        return is_join
+
     def get_all_tensor_names(self):
         """Returns a list of all (input, output and value_info) tensor names in
         the graph."""
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 1886c785705161c3a13493de44dc3f3f86463f4f..b91ffdb3f731d27d9a6ba68b090f3881e6d7293a 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -36,8 +36,9 @@ from finn.util.basic import get_by_name
 
 
 class MoveAddPastMul(Transformation):
-    """Move add operations past multiply operations. The aim is to have them
-    next to each other such that they can be collapsed into a single add."""
+    """Move add operations past multiply operations on linear segments of the graph.
+    The aim is to have them next to each other such that they can be collapsed into
+    a single add."""
 
     def apply(self, model):
         graph = model.graph
@@ -45,9 +46,17 @@
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Mul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Mul"
+                    and not model.is_join_node(consumer)
+                ):
                     # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
                     # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
                     # assume input 0 is from the previous layer, input 1 is the
@@ -63,12 +72,16 @@
                     end_name = consumer.output[0]
                     # compute new param value for add
                     BA = B * A
+                    # make and insert new nodes
                     new_mul = oh.make_node(
-                        "Mul", [start_name, mul_weight_name], [middle_name]
+                        "Mul",
+                        [start_name, mul_weight_name],
+                        [middle_name],
+                        name=consumer.name,
                     )
                     new_add = oh.make_node(
-                        "Add", [middle_name, add_weight_name], [end_name]
+                        "Add", [middle_name, add_weight_name], [end_name], name=n.name
                     )
                     graph.node.insert(node_ind, new_mul)
                     graph.node.insert(node_ind + 1, new_add)
@@ -78,6 +91,7 @@
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
+        model = model.transform(InferShapes())
         return (model, graph_modified)
 
 
@@ -92,9 +106,17 @@
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "MatMul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MatMul"
+                    and not model.is_join_node(consumer)
+                ):
                     mul_weight_name = n.input[1]
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
@@ -109,10 +131,16 @@
                     # if the mul is scalar, we can simply swap the order of ops
                     # make and insert new nodes
                     new_matmul = oh.make_node(
-                        "MatMul", [start_name, matmul_weight_name], [middle_name]
+                        "MatMul",
+                        [start_name, matmul_weight_name],
+                        [middle_name],
+                        name=consumer.name,
                     )
                     new_mul = oh.make_node(
-                        "Mul", [middle_name, mul_weight_name], [end_name]
+                        "Mul",
+                        [middle_name, mul_weight_name],
+                        [end_name],
+                        name=n.name,
                     )
                     graph.node.insert(node_ind, new_matmul)
                     graph.node.insert(node_ind + 1, new_mul)
@@ -135,9 +163,17 @@
         graph_modified = False
        for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "MatMul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MatMul"
+                    and not model.is_join_node(consumer)
+                ):
                     add_weight_name = n.input[1]
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(add_weight_name)
@@ -155,10 +191,16 @@
                     # update the add weight
                     model.set_initializer(add_weight_name, Anew)
                     new_matmul = oh.make_node(
-                        "MatMul", [start_name, matmul_weight_name], [middle_name]
+                        "MatMul",
+                        [start_name, matmul_weight_name],
+                        [middle_name],
+                        name=consumer.name,
                     )
                     new_add = oh.make_node(
-                        "Add", [middle_name, add_weight_name], [end_name]
+                        "Add",
+                        [middle_name, add_weight_name],
+                        [end_name],
+                        name=n.name,
                     )
                     graph.node.insert(node_ind, new_matmul)
                     graph.node.insert(node_ind + 1, new_add)
@@ -181,9 +223,17 @@
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Conv":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
                     conv_node = consumer
                     add_node = n
                     add_weight_name = n.input[1]
@@ -238,9 +288,17 @@
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Conv":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
                     mul_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
                     assert A is not None, "Initializer for mul weights is not set."
diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py
index d1da6934a5db07aabe41a9ca40b5de497b6460a1..4bd9385536bc6721c66726169dfa4c69e5f06772 100644
--- a/tests/core/test_modelwrapper.py
+++ b/tests/core/test_modelwrapper.py
@@ -127,3 +127,45 @@ def test_modelwrapper_graph_order():
     assert model.get_node_index(Round_node) == 1
     assert model.get_node_index(Ceil_node) == 2
     assert model.get_node_index(Add_node) == 3
+
+
+def test_modelwrapper_detect_forks_n_joins():
+    # create small network with properties to be tested
+    Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])
+    Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"])
+
+    Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
+    Add_node = onnx.helper.make_node(
+        "Add", inputs=["round1", "ceil1"], outputs=["out1"]
+    )
+
+    in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
+    out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
+
+    graph = onnx.helper.make_graph(
+        nodes=[Neg_node, Round_node, Ceil_node, Add_node],
+        name="simple_graph",
+        inputs=[in1],
+        outputs=[out1],
+        value_info=[
+            onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
+            onnx.helper.make_tensor_value_info(
+                "round1", onnx.TensorProto.FLOAT, [4, 4]
+            ),
+            onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
+        ],
+    )
+
+    onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    # test
+    assert model.is_fork_node(Neg_node)
+    assert not model.is_fork_node(Round_node)
+    assert not model.is_fork_node(Ceil_node)
+    assert not model.is_fork_node(Add_node)
+
+    assert not model.is_join_node(Neg_node)
+    assert not model.is_join_node(Round_node)
+    assert not model.is_join_node(Ceil_node)
+    assert model.is_join_node(Add_node)
diff --git a/tests/transformation/test_move_add_past_mul.py b/tests/transformation/test_move_add_past_mul.py
index a0516d6fb2ff985fc112185ce99ad8facd841caf..163b9d310a5f12bd0b854f9aa46f53a549bf109e 100644
--- a/tests/transformation/test_move_add_past_mul.py
+++ b/tests/transformation/test_move_add_past_mul.py
@@ -60,6 +60,9 @@ def test_move_add_past_mul_single():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
 
 
 def test_move_add_past_mul_multi():
@@ -92,3 +95,50 @@
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Add"
+    assert new_model.graph.node[3].op_type == "Add"
+    for i in range(len(new_model.graph.node) - 1):
+        assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0]
+
+
+def test_move_add_past_mul_only_if_linear():
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2])
+
+    value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])]
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Add", ["top_in", "add1_param"], ["t1"]),
+                oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]),
+                oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]),
+                oh.make_node("Add", ["t3", "fork"], ["t4"]),
+                oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("add1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32))
+    new_model = model.transform(MoveAddPastMul())
+    inp_dict = {"top_in": np.random.rand(2).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Add"
+    assert new_model.graph.node[4].op_type == "Mul"
diff --git a/tests/transformation/test_move_scalar_past_conv.py b/tests/transformation/test_move_scalar_past_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9992d17b96ab5f419f3ac495f126ddfa736349a2
--- /dev/null
+++ b/tests/transformation/test_move_scalar_past_conv.py
@@ -0,0 +1,87 @@
+import numpy as np
+import onnx.helper as oh
+import pytest
+from onnx import TensorProto
+
+import finn.core.onnx_exec as ox
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import (
+    MoveScalarAddPastConv,
+    MoveScalarMulPastConv,
+)
+
+
+@pytest.mark.parametrize(
+    "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
+)
+def test_move_scalar_past_conv_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    in_feature_dim = 7
+    in_chn = 1
+    padding = False
+    stages = 3
+    kernel_size = 3
+
+    out_feature_dim = (
+        in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
+    )
+
+    input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
+    output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
+
+    conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
+
+    conv_config = {}
+    conv_config["dilations"] = [1, 1]
+    conv_config["group"] = 1
+    conv_config["kernel_shape"] = [kernel_size, kernel_size]
+    conv_config["pads"] = [0, 0, 0, 0]
+    conv_config["strides"] = [1, 1]
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, conv_param_shape)]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Conv", ["top_in", "p2"], ["t1"], **conv_config),
+                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
+                oh.make_node("Conv", ["t2", "p3"], ["t3"], **conv_config),
+                oh.make_node("Conv", ["t2", "p4"], ["t4"], **conv_config),
+                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
+                oh.make_node("Conv", ["t5", "p5"], ["top_out"], **conv_config),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p4", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p5", np.random.rand(*conv_param_shape).astype(np.float32))
+    new_model = model.transform(transf_fxn)
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Conv"
+    assert new_model.graph.node[1].op_type == scalar_op
+    assert new_model.graph.node[2].op_type == "Conv"
+    assert new_model.graph.node[3].op_type == "Conv"
+    assert new_model.graph.node[4].op_type == scalar_op
+    assert new_model.graph.node[5].op_type == "Conv"
diff --git a/tests/transformation/test_move_scalar_past_matmul.py b/tests/transformation/test_move_scalar_past_matmul.py
index 896527e82d8cfa869cb979d1102904c70703a14c..e432dbf4ec1a38551609e5914e2d44968a020908 100644
--- a/tests/transformation/test_move_scalar_past_matmul.py
+++ b/tests/transformation/test_move_scalar_past_matmul.py
@@ -27,6 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+import pytest
 import onnx.helper as oh
 from onnx import TensorProto
 
@@ -99,3 +100,56 @@
     assert new_model.graph.node[0].op_type == "MatMul"
     assert new_model.graph.node[1].op_type == "Add"
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
+
+
+@pytest.mark.parametrize(
+    "test_args",
+    [("Add", MoveScalarAddPastMatMul()), ("Mul", MoveScalarMulPastMatMul())],
+)
+def test_move_scalar_past_matmul_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+    input_shape = [1, 2]
+    matmul_shape = [2, 2]
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    p1 = oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1, 1])
+    p2 = oh.make_tensor_value_info("p2", TensorProto.FLOAT, matmul_shape)
+    p3 = oh.make_tensor_value_info("p3", TensorProto.FLOAT, matmul_shape)
+    p4 = oh.make_tensor_value_info("p4", TensorProto.FLOAT, matmul_shape)
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=[p1, p2, p3, p4],
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
+                oh.make_node("MatMul", ["t1", "p2"], ["fork"]),
+                oh.make_node("MatMul", ["fork", "p3"], ["t3"]),
+                oh.make_node(scalar_op, ["t3", "fork"], ["t4"]),
+                oh.make_node("MatMul", ["t4", "p4"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", np.random.rand(1, 1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*matmul_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*matmul_shape).astype(np.float32))
+    model.set_initializer("p4", np.random.rand(*matmul_shape).astype(np.float32))
+
+    # Transform
+    new_model = model.transform(transf_fxn)
+
+    # Test
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MatMul"
+    assert new_model.graph.node[1].op_type == scalar_op
+    assert new_model.graph.node[2].op_type == "MatMul"
+    assert new_model.graph.node[3].op_type == scalar_op
+    assert new_model.graph.node[4].op_type == "MatMul"
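Background on the convolution tests above (illustrative only, not part of the patch): a scalar Mul commutes with a Conv because convolution is linear, whereas a scalar Add can only be moved past a Conv by accounting for the kernel sum, and that identity only holds exactly when the convolution has no padding (padded zeros are not offset by the constant), which is presumably why test_move_scalar_past_conv_only_if_linear builds its convolutions with pads = [0, 0, 0, 0]. A rough single-channel NumPy sketch of both identities; conv2d_valid is a hypothetical helper written just for this illustration, not a FINN or ONNX function:

import numpy as np

def conv2d_valid(x, w):
    # plain "valid" 2D cross-correlation: single channel, no padding, stride 1
    kh, kw = w.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1), dtype=x.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i : i + kh, j : j + kw] * w)
    return out

np.random.seed(0)
x = np.random.rand(7, 7).astype(np.float32)  # mirrors the 7x7 single-channel test input
w = np.random.rand(3, 3).astype(np.float32)  # mirrors the 3x3 kernel in the test
s = np.float32(0.5)   # scalar Mul parameter
c = np.float32(0.25)  # scalar Add parameter

# scalar Mul commutes with Conv (linearity)
assert np.allclose(conv2d_valid(s * x, w), s * conv2d_valid(x, w))

# scalar Add only moves past an unpadded Conv when the constant is scaled by the kernel sum
assert np.allclose(conv2d_valid(x + c, w), conv2d_valid(x, w) + c * w.sum())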