diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py new file mode 100644 index 0000000000000000000000000000000000000000..4f6a3fe38f35413afc1b5990ea5aaf84b06d02f5 --- /dev/null +++ b/src/finn/transformation/merge_onnx_models.py @@ -0,0 +1,111 @@ +from onnx import helper + +from finn.transformation import Transformation +from finn.core.modelwrapper import ModelWrapper +import finn.util.basic as util +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.general import ( + GiveReadableTensorNames, + GiveUniqueNodeNames, + GiveUniqueParameterTensors, +) + + +def _make_model_values_unique(model1, model2): + # ensure that tensor and node names are different in each model + # tensors + names1 = model1.get_all_tensor_names() + names2 = model2.get_all_tensor_names() + duplicates = list(set(names1).intersection(names2)) + # if there are duplicates in the tensor names rename these tensors + if duplicates: + used_names = names1 + names2 + for name in duplicates: + # model1 + new_name = util.random_string() + while new_name in used_names: + new_name = util.random_string() + model1.rename_tensor(name, new_name) + used_names.append(new_name) + + # model2 + new_name = util.random_string() + while new_name in used_names: + new_name = util.random_string() + model2.rename_tensor(name, new_name) + used_names.append(new_name) + + # nodes + names1 = [x.name for x in model1.graph.node] + names1 = list(filter(None, names1)) # filter out empty node names + names2 = [x.name for x in model2.graph.node] + names2 = list(filter(None, names2)) + duplicates = list(set(names1).intersection(names2)) + # if there are duplicates erase all node names + if duplicates: + for n in model1.graph.node: + n.name = "" + for n in model2.graph.node: + n.name = "" + + return (model1, model2) + + +class MergeONNXModels(Transformation): + def __init__(self, pre_proc_model): + super().__init__() + self.pre_proc_model = pre_proc_model + + def apply(self, model): + graph_modified = False + pre_proc_model = self.pre_proc_model + + (pre_proc_model, model) = _make_model_values_unique(pre_proc_model, model) + + node_list_a = pre_proc_model.graph.node + node_list_b = model.graph.node + node_list = node_list_a + node_list[-1].output[0] = node_list_b[0].input[0] + for node in node_list_b: + node_list.append(node) + inp = pre_proc_model.graph.input[0] + outp = model.graph.output[0] + new_graph = helper.make_graph( + nodes=node_list, + name="fuse-graph", + inputs=[inp], + outputs=[outp], + value_info=[], + ) + + new_model = helper.make_model(new_graph, producer_name="fuse_model") + new_model = ModelWrapper(new_model) + vi_preproc = [x for x in pre_proc_model.graph.input] + vi_preproc += [x for x in pre_proc_model.graph.output] + vi_preproc += [x for x in pre_proc_model.graph.value_info] + for vi in vi_preproc: + if vi == inp: + continue + new_model.graph.value_info.append(vi) + init_val = pre_proc_model.get_initializer(vi.name) + if init_val is not None: + new_model.set_initializer(vi.name, init_val) + vi_model = [x for x in model.graph.input] + vi_model += [x for x in model.graph.output] + vi_model += [x for x in model.graph.value_info] + for vi in vi_model: + if vi == outp: + continue + new_model.graph.value_info.append(vi) + init_val = model.get_initializer(vi.name) + if init_val is not None: + new_model.set_initializer(vi.name, init_val) + + model = new_model + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveUniqueParameterTensors()) + model = model.transform(GiveReadableTensorNames()) + return (model, graph_modified)