From c2a6a615111edef534eb4a241c9c590ec146c884 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 1 Jul 2020 10:43:24 +0100
Subject: [PATCH] [Transform] Fix some bug in MergeONNXModels

---
 src/finn/transformation/merge_onnx_models.py | 68 ++++++++++++++++----
 1 file changed, 57 insertions(+), 11 deletions(-)

diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
index 4f6a3fe38..4cc056ead 100644
--- a/src/finn/transformation/merge_onnx_models.py
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -1,3 +1,32 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import copy
+
 from onnx import helper
 
 from finn.transformation import Transformation
@@ -53,23 +82,40 @@ def _make_model_values_unique(model1, model2):
 
 
 class MergeONNXModels(Transformation):
-    def __init__(self, pre_proc_model):
+    """Merges two models. The model passed in the transformation will be inserted before
+    the model the transformation is applied on, the resulting model is returned."""
+
+    def __init__(self, pre_model):
         super().__init__()
-        self.pre_proc_model = pre_proc_model
+        # use deep copy of model that should be inserted in the beginning of
+        # the other model to ensure that it stays unchanged
+        self.pre_model = copy.deepcopy(pre_model)
 
     def apply(self, model):
         graph_modified = False
-        pre_proc_model = self.pre_proc_model
+        pre_model = self.pre_model
 
-        (pre_proc_model, model) = _make_model_values_unique(pre_proc_model, model)
+        (pre_model, model) = _make_model_values_unique(pre_model, model)
 
-        node_list_a = pre_proc_model.graph.node
+        # check if models can be merged
+        output_model_a = pre_model.graph.output[0].name
+        input_model_b = model.graph.input[0].name
+        output_a_shape = pre_model.get_tensor_shape(output_model_a)
+        input_b_shape = model.get_tensor_shape(input_model_b)
+        assert (
+            output_a_shape == input_b_shape
+        ), "Models can't be merged! Shapes don't match."
+        for n in pre_model.graph.node:
+            if output_model_a == n.output[0]:
+                n.output[0] = input_model_b
+
+        node_list_a = pre_model.graph.node
         node_list_b = model.graph.node
+
         node_list = node_list_a
-        node_list[-1].output[0] = node_list_b[0].input[0]
         for node in node_list_b:
             node_list.append(node)
-        inp = pre_proc_model.graph.input[0]
+        inp = pre_model.graph.input[0]
         outp = model.graph.output[0]
         new_graph = helper.make_graph(
             nodes=node_list,
@@ -81,14 +127,14 @@ class MergeONNXModels(Transformation):
 
         new_model = helper.make_model(new_graph, producer_name="fuse_model")
         new_model = ModelWrapper(new_model)
-        vi_preproc = [x for x in pre_proc_model.graph.input]
-        vi_preproc += [x for x in pre_proc_model.graph.output]
-        vi_preproc += [x for x in pre_proc_model.graph.value_info]
+        vi_preproc = [x for x in pre_model.graph.input]
+        vi_preproc += [x for x in pre_model.graph.output]
+        vi_preproc += [x for x in pre_model.graph.value_info]
         for vi in vi_preproc:
             if vi == inp:
                 continue
             new_model.graph.value_info.append(vi)
-            init_val = pre_proc_model.get_initializer(vi.name)
+            init_val = pre_model.get_initializer(vi.name)
             if init_val is not None:
                 new_model.set_initializer(vi.name, init_val)
         vi_model = [x for x in model.graph.input]
-- 
GitLab