From 6358106280872730fa7468d303b148aa9756f695 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 1 Jul 2020 10:50:24 +0100 Subject: [PATCH] [Transform] Add comments to MergeONNXModels --- src/finn/transformation/merge_onnx_models.py | 22 ++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py index 4cc056ead..0f46d0ab0 100644 --- a/src/finn/transformation/merge_onnx_models.py +++ b/src/finn/transformation/merge_onnx_models.py @@ -95,6 +95,7 @@ class MergeONNXModels(Transformation): graph_modified = False pre_model = self.pre_model + # make model values unique to avoid any conflict (pre_model, model) = _make_model_values_unique(pre_model, model) # check if models can be merged @@ -105,18 +106,27 @@ class MergeONNXModels(Transformation): assert ( output_a_shape == input_b_shape ), "Models can't be merged! Shapes don't match." + + # connect output of one model to input of the other for n in pre_model.graph.node: if output_model_a == n.output[0]: n.output[0] = input_model_b + # extract information for new model + + # nodes node_list_a = pre_model.graph.node node_list_b = model.graph.node node_list = node_list_a for node in node_list_b: node_list.append(node) + + # in and output inp = pre_model.graph.input[0] outp = model.graph.output[0] + + # create new graph and model new_graph = helper.make_graph( nodes=node_list, name="fuse-graph", @@ -127,10 +137,12 @@ class MergeONNXModels(Transformation): new_model = helper.make_model(new_graph, producer_name="fuse_model") new_model = ModelWrapper(new_model) - vi_preproc = [x for x in pre_model.graph.input] - vi_preproc += [x for x in pre_model.graph.output] - vi_preproc += [x for x in pre_model.graph.value_info] - for vi in vi_preproc: + + # add value info and initializers from both models to new model + vi_pre = [x for x in pre_model.graph.input] + vi_pre += [x for x in pre_model.graph.output] + vi_pre += [x for x in pre_model.graph.value_info] + for vi in vi_pre: if vi == inp: continue new_model.graph.value_info.append(vi) @@ -148,10 +160,12 @@ class MergeONNXModels(Transformation): if init_val is not None: new_model.set_initializer(vi.name, init_val) + # tidy-up new model model = new_model model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveUniqueParameterTensors()) model = model.transform(GiveReadableTensorNames()) + return (model, graph_modified) -- GitLab