From c2a6a615111edef534eb4a241c9c590ec146c884 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 1 Jul 2020 10:43:24 +0100 Subject: [PATCH] [Transform] Fix some bug in MergeONNXModels --- src/finn/transformation/merge_onnx_models.py | 68 ++++++++++++++++---- 1 file changed, 57 insertions(+), 11 deletions(-) diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py index 4f6a3fe38..4cc056ead 100644 --- a/src/finn/transformation/merge_onnx_models.py +++ b/src/finn/transformation/merge_onnx_models.py @@ -1,3 +1,32 @@ +# Copyright (c) 2020, Xilinx +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import copy + from onnx import helper from finn.transformation import Transformation @@ -53,23 +82,40 @@ def _make_model_values_unique(model1, model2): class MergeONNXModels(Transformation): - def __init__(self, pre_proc_model): + """Merges two models. The model passed in the transformation will be inserted before + the model the transformation is applied on, the resulting model is returned.""" + + def __init__(self, pre_model): super().__init__() - self.pre_proc_model = pre_proc_model + # use deep copy of model that should be inserted in the beginning of + # the other model to ensure that it stays unchanged + self.pre_model = copy.deepcopy(pre_model) def apply(self, model): graph_modified = False - pre_proc_model = self.pre_proc_model + pre_model = self.pre_model - (pre_proc_model, model) = _make_model_values_unique(pre_proc_model, model) + (pre_model, model) = _make_model_values_unique(pre_model, model) - node_list_a = pre_proc_model.graph.node + # check if models can be merged + output_model_a = pre_model.graph.output[0].name + input_model_b = model.graph.input[0].name + output_a_shape = pre_model.get_tensor_shape(output_model_a) + input_b_shape = model.get_tensor_shape(input_model_b) + assert ( + output_a_shape == input_b_shape + ), "Models can't be merged! Shapes don't match." + for n in pre_model.graph.node: + if output_model_a == n.output[0]: + n.output[0] = input_model_b + + node_list_a = pre_model.graph.node node_list_b = model.graph.node + node_list = node_list_a - node_list[-1].output[0] = node_list_b[0].input[0] for node in node_list_b: node_list.append(node) - inp = pre_proc_model.graph.input[0] + inp = pre_model.graph.input[0] outp = model.graph.output[0] new_graph = helper.make_graph( nodes=node_list, @@ -81,14 +127,14 @@ class MergeONNXModels(Transformation): new_model = helper.make_model(new_graph, producer_name="fuse_model") new_model = ModelWrapper(new_model) - vi_preproc = [x for x in pre_proc_model.graph.input] - vi_preproc += [x for x in pre_proc_model.graph.output] - vi_preproc += [x for x in pre_proc_model.graph.value_info] + vi_preproc = [x for x in pre_model.graph.input] + vi_preproc += [x for x in pre_model.graph.output] + vi_preproc += [x for x in pre_model.graph.value_info] for vi in vi_preproc: if vi == inp: continue new_model.graph.value_info.append(vi) - init_val = pre_proc_model.get_initializer(vi.name) + init_val = pre_model.get_initializer(vi.name) if init_val is not None: new_model.set_initializer(vi.name, init_val) vi_model = [x for x in model.graph.input] -- GitLab