class MoveOpPastFork(Transformation):
    """Move node operations past graph forks. Used when a node before a fork
    can be merged with nodes in the branches.

    Only ops whose parameter (assumed to be input 1) is a constant
    initializer are moved: the op is duplicated once per extra branch so
    that each branch receives its own copy with its own parameter tensor.
    """

    def __init__(self, op_name_list):
        super().__init__()
        # ONNX op_type strings this transform is allowed to move past a fork
        self.ops_to_move = op_name_list

    def apply(self, model):
        """Duplicate movable fork ops into each branch.

        Returns (model, graph_modified) per the Transformation contract;
        shapes are re-inferred before returning since new tensors are added.
        """
        graph = model.graph
        graph_modified = False
        # iterate over a snapshot: nodes are inserted/removed while walking
        nodes = list(graph.node)
        node_ind = 0
        for n in nodes:
            node_ind += 1
            if (
                n.op_type in self.ops_to_move
                and model.is_fork_node(n)
                and not model.is_join_node(n)
            ):

                # Restrict this transform to operations with constant
                # parameters. Assuming parameters is in input 1.
                op_init_param = model.get_initializer(n.input[1])
                if op_init_param is None:
                    continue

                # Check case when branches are empty and go
                # to the same node: nothing to duplicate, skip.
                consumers = model.find_consumers(n.output[0])
                if all(consumers[0] == c for c in consumers[1:]):
                    continue

                # Keep the original op feeding the first branch; every other
                # branch gets its own duplicate of the op.
                for consumer_node in consumers[1:]:
                    # create new node (duplicate of n) for this branch
                    new_param_name = model.make_new_valueinfo_name()
                    new_output_tensor_name = model.make_new_valueinfo_name()
                    new_node = oh.make_node(
                        n.op_type,
                        [n.input[0], new_param_name],
                        [new_output_tensor_name],
                    )
                    graph.node.insert(node_ind, new_node)
                    node_ind += 1
                    # NOTE(review): the same initializer array is shared by
                    # all duplicates — fine as long as it is never mutated
                    # in place downstream.
                    model.set_initializer(new_param_name, op_init_param)

                    # change consumer input tensor to the duplicate's output;
                    # re-insert the consumer after the new node so the node
                    # list stays topologically ordered
                    graph.node.remove(consumer_node)
                    for idx, consumer_input in enumerate(consumer_node.input):
                        if consumer_input == n.output[0]:
                            consumer_node.input[idx] = new_output_tensor_name
                            break
                    else:
                        raise Exception(
                            "Consumer should have the current node output as input"
                        )

                    graph.node.insert(node_ind, consumer_node)

                graph_modified = True

        model = model.transform(InferShapes())
        return (model, graph_modified)


class MoveAddPastFork(MoveOpPastFork):
    """Specialization of MoveOpPastFork for Add nodes."""

    def __init__(self):
        super().__init__(["Add"])


class MoveMulPastFork(MoveOpPastFork):
    """Specialization of MoveOpPastFork for Mul nodes."""

    def __init__(self):
        super().__init__(["Mul"])


class MoveLinearPastFork(MoveOpPastFork):
    """Specialization of MoveOpPastFork for linear (Add and Mul) nodes."""

    def __init__(self):
        super().__init__(["Add", "Mul"])