From b9c1cd76b9e5862834c7fb7db835ae83947c6a2a Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Mon, 6 Jul 2020 10:45:36 +0100 Subject: [PATCH] [Transform] Update MergeONNXModels --- src/finn/transformation/merge_onnx_models.py | 139 ++++++++++++------- 1 file changed, 86 insertions(+), 53 deletions(-) diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py index 0f46d0ab0..18af62daf 100644 --- a/src/finn/transformation/merge_onnx_models.py +++ b/src/finn/transformation/merge_onnx_models.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import copy - +import warnings from onnx import helper from finn.transformation import Transformation @@ -41,49 +41,12 @@ from finn.transformation.general import ( ) -def _make_model_values_unique(model1, model2): - # ensure that tensor and node names are different in each model - # tensors - names1 = model1.get_all_tensor_names() - names2 = model2.get_all_tensor_names() - duplicates = list(set(names1).intersection(names2)) - # if there are duplicates in the tensor names rename these tensors - if duplicates: - used_names = names1 + names2 - for name in duplicates: - # model1 - new_name = util.random_string() - while new_name in used_names: - new_name = util.random_string() - model1.rename_tensor(name, new_name) - used_names.append(new_name) - - # model2 - new_name = util.random_string() - while new_name in used_names: - new_name = util.random_string() - model2.rename_tensor(name, new_name) - used_names.append(new_name) - - # nodes - names1 = [x.name for x in model1.graph.node] - names1 = list(filter(None, names1)) # filter out empty node names - names2 = [x.name for x in model2.graph.node] - names2 = list(filter(None, names2)) - duplicates = list(set(names1).intersection(names2)) - # if there are duplicates erase all node names - if duplicates: - for n in model1.graph.node: - n.name = "" - for n in model2.graph.node: - n.name = "" - - return (model1, model2) - - class MergeONNXModels(Transformation): """Merges two models. The model passed in the transformation will be inserted before - the model the transformation is applied on, the resulting model is returned.""" + the model the transformation is applied on, the resulting model is returned. + This transformation will try to connect graph.output[0] of the pre model and + graph.input[0] of the post model. + If more than one input or output exists, a warning is raised.""" def __init__(self, pre_model): super().__init__() @@ -94,15 +57,52 @@ class MergeONNXModels(Transformation): def apply(self, model): graph_modified = False pre_model = self.pre_model + post_model = copy.deepcopy(model) - # make model values unique to avoid any conflict - (pre_model, model) = _make_model_values_unique(pre_model, model) + if len(pre_model.graph.output) != 1: + warnings.warn( + "The pre model has more than one output! The transformation tries " + "to connect output[0] to the input of the post model." + ) + + if len(post_model.graph.input) != 1: + warnings.warn( + "The post model has more than one input! The transformation tries " + "to connect input[0] to the output of the pre model." + ) + + # erase all node names to avoid conflict + for n in pre_model.graph.node: + n.name = "" + for n in post_model.graph.node: + n.name = "" + + # randomize all tensor names + names1 = pre_model.get_all_tensor_names() + names2 = post_model.get_all_tensor_names() + used_names = names1 + names2 + + # pre_model + for tensor_name in names1: + new_name = util.random_string() + while new_name in used_names: + new_name = util.random_string() + pre_model.rename_tensor(tensor_name, new_name) + used_names.append(new_name) + + # post_model + for tensor in names2: + new_name = util.random_string() + while new_name in used_names: + new_name = util.random_string() + post_model.rename_tensor(tensor_name, new_name) + used_names.append(new_name) # check if models can be merged output_model_a = pre_model.graph.output[0].name - input_model_b = model.graph.input[0].name + input_model_b = post_model.graph.input[0].name output_a_shape = pre_model.get_tensor_shape(output_model_a) - input_b_shape = model.get_tensor_shape(input_model_b) + input_b_shape = post_model.get_tensor_shape(input_model_b) assert ( output_a_shape == input_b_shape ), "Models can't be merged! Shapes don't match." @@ -116,7 +116,7 @@ class MergeONNXModels(Transformation): # nodes node_list_a = pre_model.graph.node - node_list_b = model.graph.node + node_list_b = post_model.graph.node node_list = node_list_a for node in node_list_b: @@ -124,7 +124,7 @@ class MergeONNXModels(Transformation): # in and output inp = pre_model.graph.input[0] - outp = model.graph.output[0] + outp = post_model.graph.output[0] # create new graph and model new_graph = helper.make_graph( @@ -138,27 +138,60 @@ class MergeONNXModels(Transformation): new_model = helper.make_model(new_graph, producer_name="fuse_model") new_model = ModelWrapper(new_model) - # add value info and initializers from both models to new model + # add value info from both models to new model + # pre model vi_pre = [x for x in pre_model.graph.input] vi_pre += [x for x in pre_model.graph.output] vi_pre += [x for x in pre_model.graph.value_info] for vi in vi_pre: + # graph input should not be part of graph.value_info, so skip + # if current vi == inp if vi == inp: continue new_model.graph.value_info.append(vi) + # preserve intializers, quantization/sparsity annotation, etc. + # initializer init_val = pre_model.get_initializer(vi.name) if init_val is not None: new_model.set_initializer(vi.name, init_val) - vi_model = [x for x in model.graph.input] - vi_model += [x for x in model.graph.output] - vi_model += [x for x in model.graph.value_info] + # FINN datatype + dtype = pre_model.get_tensor_datatype(vi.name) + new_model.set_tensor_datatype(vi.name, dtype) + # data layout + data_layout = pre_model.get_tensor_layout(vi.name) + if data_layout is not None: + new_model.set_tensor_layout(vi.name, data_layout) + # sparsity + sparsity = pre_model.get_tensor_sparsity(vi.name) + if sparsity is not None: + new_model.set_tensor_sparsity(vi.name, sparsity) + + # post model + vi_model = [x for x in post_model.graph.input] + vi_model += [x for x in post_model.graph.output] + vi_model += [x for x in post_model.graph.value_info] for vi in vi_model: + # graph output should not be part of graph.value_info, so skip + # if current vi == outp if vi == outp: continue new_model.graph.value_info.append(vi) - init_val = model.get_initializer(vi.name) + # preserve intializers, quantization/sparsity annotation, etc. + # initializer + init_val = post_model.get_initializer(vi.name) if init_val is not None: new_model.set_initializer(vi.name, init_val) + # FINN datatype + dtype = post_model.get_tensor_datatype(vi.name) + new_model.set_tensor_datatype(vi.name, dtype) + # data layout + data_layout = post_model.get_tensor_layout(vi.name) + if data_layout is not None: + new_model.set_tensor_layout(vi.name, data_layout) + # sparsity + sparsity = post_model.get_tensor_sparsity(vi.name) + if sparsity is not None: + new_model.set_tensor_sparsity(vi.name, sparsity) # tidy-up new model model = new_model -- GitLab