From b9c1cd76b9e5862834c7fb7db835ae83947c6a2a Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Mon, 6 Jul 2020 10:45:36 +0100
Subject: [PATCH] [Transform] Update MergeONNXModels

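Drop the _make_model_values_unique helper: node names are now cleared and
all tensor names randomized up front, and tensor annotations (initializers,
FINN datatypes, data layouts, sparsity) are carried over from both models
into the merged model.

A minimal usage sketch (file names are placeholders; assumes the usual
ModelWrapper import path and two shape-compatible models):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.merge_onnx_models import MergeONNXModels

    pre_model = ModelWrapper("pre.onnx")    # inserted in front
    post_model = ModelWrapper("post.onnx")  # transformation is applied on this
    # returns a new merged model with pre_model's graph feeding post_model's
    merged_model = post_model.transform(MergeONNXModels(pre_model))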
---
 src/finn/transformation/merge_onnx_models.py | 139 ++++++++++++-------
 1 file changed, 86 insertions(+), 53 deletions(-)

diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
index 0f46d0ab0..18af62daf 100644
--- a/src/finn/transformation/merge_onnx_models.py
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -26,7 +26,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import copy
-
+import warnings
 from onnx import helper
 
 from finn.transformation import Transformation
@@ -41,49 +41,12 @@ from finn.transformation.general import (
 )
 
 
-def _make_model_values_unique(model1, model2):
-    # ensure that tensor and node names are different in each model
-    # tensors
-    names1 = model1.get_all_tensor_names()
-    names2 = model2.get_all_tensor_names()
-    duplicates = list(set(names1).intersection(names2))
-    # if there are duplicates in the tensor names rename these tensors
-    if duplicates:
-        used_names = names1 + names2
-        for name in duplicates:
-            # model1
-            new_name = util.random_string()
-            while new_name in used_names:
-                new_name = util.random_string()
-            model1.rename_tensor(name, new_name)
-            used_names.append(new_name)
-
-            # model2
-            new_name = util.random_string()
-            while new_name in used_names:
-                new_name = util.random_string()
-            model2.rename_tensor(name, new_name)
-            used_names.append(new_name)
-
-    # nodes
-    names1 = [x.name for x in model1.graph.node]
-    names1 = list(filter(None, names1))  # filter out empty node names
-    names2 = [x.name for x in model2.graph.node]
-    names2 = list(filter(None, names2))
-    duplicates = list(set(names1).intersection(names2))
-    # if there are duplicates erase all node names
-    if duplicates:
-        for n in model1.graph.node:
-            n.name = ""
-        for n in model2.graph.node:
-            n.name = ""
-
-    return (model1, model2)
-
-
 class MergeONNXModels(Transformation):
     """Merges two models. The model passed in the transformation will be inserted before
-    the model the transformation is applied on, the resulting model is returned."""
+    the model the transformation is applied on, the resulting model is returned.
+    This transformation tries to connect graph.output[0] of the pre model to
+    graph.input[0] of the post model.
+    If either model has more than one input or output, a warning is raised."""
 
     def __init__(self, pre_model):
         super().__init__()
@@ -94,15 +57,52 @@ class MergeONNXModels(Transformation):
     def apply(self, model):
         graph_modified = False
         pre_model = self.pre_model
+        post_model = copy.deepcopy(model)
 
-        # make model values unique to avoid any conflict
-        (pre_model, model) = _make_model_values_unique(pre_model, model)
+        if len(pre_model.graph.output) != 1:
+            warnings.warn(
+                "The pre model has more than one output! The transformation tries "
+                "to connect output[0] to the input of the post model."
+            )
+
+        if len(post_model.graph.input) != 1:
+            warnings.warn(
+                "The post model has more than one input! The transformation tries "
+                "to connect input[0] to the output of the pre model."
+            )
+
+        # erase all node names to avoid conflicts
+        for n in pre_model.graph.node:
+            n.name = ""
+        for n in post_model.graph.node:
+            n.name = ""
+
+        # randomize all tensor names
+        names1 = pre_model.get_all_tensor_names()
+        names2 = post_model.get_all_tensor_names()
+        used_names = names1 + names2
+
+        # pre_model
+        for tensor_name in names1:
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            pre_model.rename_tensor(tensor_name, new_name)
+            used_names.append(new_name)
+
+        # post_model
+        for tensor_name in names2:
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            post_model.rename_tensor(tensor_name, new_name)
+            used_names.append(new_name)
 
         # check if models can be merged
         output_model_a = pre_model.graph.output[0].name
-        input_model_b = model.graph.input[0].name
+        input_model_b = post_model.graph.input[0].name
         output_a_shape = pre_model.get_tensor_shape(output_model_a)
-        input_b_shape = model.get_tensor_shape(input_model_b)
+        input_b_shape = post_model.get_tensor_shape(input_model_b)
         assert (
             output_a_shape == input_b_shape
         ), "Models can't be merged! Shapes don't match."
@@ -116,7 +116,7 @@ class MergeONNXModels(Transformation):
 
         # nodes
         node_list_a = pre_model.graph.node
-        node_list_b = model.graph.node
+        node_list_b = post_model.graph.node
 
         node_list = node_list_a
         for node in node_list_b:
@@ -124,7 +124,7 @@ class MergeONNXModels(Transformation):
 
         # in and output
         inp = pre_model.graph.input[0]
-        outp = model.graph.output[0]
+        outp = post_model.graph.output[0]
 
         # create new graph and model
         new_graph = helper.make_graph(
@@ -138,27 +138,60 @@ class MergeONNXModels(Transformation):
         new_model = helper.make_model(new_graph, producer_name="fuse_model")
         new_model = ModelWrapper(new_model)
 
-        # add value info and initializers from both models to new model
+        # add value info from both models to new model
+        # pre model
         vi_pre = [x for x in pre_model.graph.input]
         vi_pre += [x for x in pre_model.graph.output]
         vi_pre += [x for x in pre_model.graph.value_info]
         for vi in vi_pre:
+            # graph input should not be part of graph.value_info, so skip
+            # if current vi == inp
             if vi == inp:
                 continue
             new_model.graph.value_info.append(vi)
+            # preserve initializers, quantization/sparsity annotations, etc.
+            # initializer
             init_val = pre_model.get_initializer(vi.name)
             if init_val is not None:
                 new_model.set_initializer(vi.name, init_val)
-        vi_model = [x for x in model.graph.input]
-        vi_model += [x for x in model.graph.output]
-        vi_model += [x for x in model.graph.value_info]
+            # FINN datatype
+            dtype = pre_model.get_tensor_datatype(vi.name)
+            new_model.set_tensor_datatype(vi.name, dtype)
+            # data layout
+            data_layout = pre_model.get_tensor_layout(vi.name)
+            if data_layout is not None:
+                new_model.set_tensor_layout(vi.name, data_layout)
+            # sparsity
+            sparsity = pre_model.get_tensor_sparsity(vi.name)
+            if sparsity is not None:
+                new_model.set_tensor_sparsity(vi.name, sparsity)
+
+        # post model
+        vi_model = [x for x in post_model.graph.input]
+        vi_model += [x for x in post_model.graph.output]
+        vi_model += [x for x in post_model.graph.value_info]
         for vi in vi_model:
+            # graph output should not be part of graph.value_info, so skip
+            # if current vi == outp
             if vi == outp:
                 continue
             new_model.graph.value_info.append(vi)
-            init_val = model.get_initializer(vi.name)
+            # preserve initializers, quantization/sparsity annotations, etc.
+            # initializer
+            init_val = post_model.get_initializer(vi.name)
             if init_val is not None:
                 new_model.set_initializer(vi.name, init_val)
+            # FINN datatype
+            dtype = post_model.get_tensor_datatype(vi.name)
+            new_model.set_tensor_datatype(vi.name, dtype)
+            # data layout
+            data_layout = post_model.get_tensor_layout(vi.name)
+            if data_layout is not None:
+                new_model.set_tensor_layout(vi.name, data_layout)
+            # sparsity
+            sparsity = post_model.get_tensor_sparsity(vi.name)
+            if sparsity is not None:
+                new_model.set_tensor_sparsity(vi.name, sparsity)
 
         # tidy-up new model
         model = new_model
-- 
GitLab