From 122c89cd197c0caff9c43a534262ec460f51cfda Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 18 Aug 2020 21:44:18 +0200 Subject: [PATCH] [Transform] fix MergeONNXModels after revealed bugs duplicate VI's prior to refactor --- src/finn/transformation/merge_onnx_models.py | 113 +++++-------------- 1 file changed, 29 insertions(+), 84 deletions(-) diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py index 5dc6127ed..ceacab197 100644 --- a/src/finn/transformation/merge_onnx_models.py +++ b/src/finn/transformation/merge_onnx_models.py @@ -31,12 +31,12 @@ from onnx import helper from finn.transformation import Transformation from finn.core.modelwrapper import ModelWrapper -import finn.util.basic as util from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_data_layouts import InferDataLayouts from finn.transformation.general import ( GiveReadableTensorNames, + GiveRandomTensorNames, GiveUniqueNodeNames, GiveUniqueParameterTensors, ) @@ -59,6 +59,9 @@ class MergeONNXModels(Transformation): graph_modified = False pre_model = self.pre_model post_model = copy.deepcopy(model) + # to avoid mix-ups, start by giving all tensors random names + pre_model = pre_model.transform(GiveRandomTensorNames()) + post_model = post_model.transform(GiveRandomTensorNames()) # check for dynamic outputs of pre model dyn_outp = [] @@ -94,27 +97,6 @@ class MergeONNXModels(Transformation): for n in post_model.graph.node: n.name = "" - # randomize all tensor names - names1 = pre_model.get_all_tensor_names() - names2 = post_model.get_all_tensor_names() - used_names = names1 + names2 - - # pre_model - for tensor_name in names1: - new_name = util.random_string() - while new_name in used_names: - new_name = util.random_string() - pre_model.rename_tensor(tensor_name, new_name) - used_names.append(new_name) - - # post_model - for tensor in names2: - new_name = util.random_string() - while new_name in used_names: - new_name = util.random_string() - post_model.rename_tensor(tensor_name, new_name) - used_names.append(new_name) - # check if models can be merged output_model_a = dyn_outp[0].name input_model_b = dyn_inp[0].name @@ -124,6 +106,9 @@ class MergeONNXModels(Transformation): output_a_shape == input_b_shape ), "Models can't be merged! Shapes don't match." + pre_model.save("pre.onnx") + post_model.save("post.onnx") + # connect output of one model to input of the other for n in pre_model.graph.node: if output_model_a == n.output[0]: @@ -132,83 +117,43 @@ class MergeONNXModels(Transformation): # extract information for new model # nodes - node_list_a = pre_model.graph.node - node_list_b = post_model.graph.node - - node_list = node_list_a - for node in node_list_b: - node_list.append(node) + node_pre = [node for node in pre_model.graph.node] + node_post = [node for node in post_model.graph.node] + node_new = node_pre + node_post # in and output inp = pre_model.graph.input[0] outp = post_model.graph.output[0] + vi_pre = [x for x in pre_model.graph.value_info] + out_pre = [x for x in pre_model.graph.output] + qa_pre = [x for x in pre_model.graph.quantization_annotation] + init_pre = [x for x in pre_model.graph.initializer] + + vi_post = [x for x in post_model.graph.value_info] + qa_post = [x for x in post_model.graph.quantization_annotation] + init_post = [x for x in post_model.graph.initializer] + + vi_new = vi_pre + vi_post + out_pre + qa_new = qa_pre + qa_post + init_new = init_pre + init_post + # create new graph and model new_graph = helper.make_graph( - nodes=node_list, + nodes=node_new, name="fuse-graph", inputs=[inp], outputs=[outp], - value_info=[], + value_info=vi_new, ) new_model = helper.make_model(new_graph, producer_name="fuse_model") new_model = ModelWrapper(new_model) - # add value info from both models to new model - # pre model - vi_pre = [x for x in pre_model.graph.input] - vi_pre += [x for x in pre_model.graph.output] - vi_pre += [x for x in pre_model.graph.value_info] - for vi in vi_pre: - # preserve intializers, quantization/sparsity annotation, etc. - # initializer - init_val = pre_model.get_initializer(vi.name) - if init_val is not None: - new_model.set_initializer(vi.name, init_val) - # FINN datatype - dtype = pre_model.get_tensor_datatype(vi.name) - new_model.set_tensor_datatype(vi.name, dtype) - # data layout - data_layout = pre_model.get_tensor_layout(vi.name) - if data_layout is not None: - new_model.set_tensor_layout(vi.name, data_layout) - # sparsity - sparsity = pre_model.get_tensor_sparsity(vi.name) - if sparsity is not None: - new_model.set_tensor_sparsity(vi.name, sparsity) - # graph input should not be part of graph.value_info, so don't insert - # if current vi == inp, but the quantization annotation is preserved - if vi == inp: - continue - new_model.graph.value_info.append(vi) - - # post model - vi_model = [x for x in post_model.graph.input] - vi_model += [x for x in post_model.graph.output] - vi_model += [x for x in post_model.graph.value_info] - for vi in vi_model: - # preserve intializers, quantization/sparsity annotation, etc. - # initializer - init_val = post_model.get_initializer(vi.name) - if init_val is not None: - new_model.set_initializer(vi.name, init_val) - # FINN datatype - dtype = post_model.get_tensor_datatype(vi.name) - new_model.set_tensor_datatype(vi.name, dtype) - # data layout - data_layout = post_model.get_tensor_layout(vi.name) - if data_layout is not None: - new_model.set_tensor_layout(vi.name, data_layout) - # sparsity - sparsity = post_model.get_tensor_sparsity(vi.name) - if sparsity is not None: - new_model.set_tensor_sparsity(vi.name, sparsity) - # graph output should not be part of graph.value_info, so don't insert - # if current vi == outp, but the quantization annotation is preserved - if vi == outp: - continue - new_model.graph.value_info.append(vi) + for i in init_new: + new_model.graph.initializer.append(i) + for qa in qa_new: + new_model.graph.quantization_annotation.append(qa) # tidy-up new model model = new_model -- GitLab