From 122c89cd197c0caff9c43a534262ec460f51cfda Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 18 Aug 2020 21:44:18 +0200
Subject: [PATCH] [Transform] fix MergeONNXModels after revealed bugs

Duplicate VIs (value_info entries) prior to the refactor
---
 src/finn/transformation/merge_onnx_models.py | 113 +++++--------------
 1 file changed, 29 insertions(+), 84 deletions(-)

diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
index 5dc6127ed..ceacab197 100644
--- a/src/finn/transformation/merge_onnx_models.py
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -31,12 +31,12 @@ from onnx import helper
 
 from finn.transformation import Transformation
 from finn.core.modelwrapper import ModelWrapper
-import finn.util.basic as util
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.transformation.general import (
     GiveReadableTensorNames,
+    GiveRandomTensorNames,
     GiveUniqueNodeNames,
     GiveUniqueParameterTensors,
 )
@@ -59,6 +59,9 @@ class MergeONNXModels(Transformation):
         graph_modified = False
         pre_model = self.pre_model
         post_model = copy.deepcopy(model)
+        # to avoid mix-ups, start by giving all tensors random names
+        pre_model = pre_model.transform(GiveRandomTensorNames())
+        post_model = post_model.transform(GiveRandomTensorNames())
 
         # check for dynamic outputs of pre model
         dyn_outp = []
@@ -94,27 +97,6 @@ class MergeONNXModels(Transformation):
         for n in post_model.graph.node:
             n.name = ""
 
-        # randomize all tensor names
-        names1 = pre_model.get_all_tensor_names()
-        names2 = post_model.get_all_tensor_names()
-        used_names = names1 + names2
-
-        # pre_model
-        for tensor_name in names1:
-            new_name = util.random_string()
-            while new_name in used_names:
-                new_name = util.random_string()
-            pre_model.rename_tensor(tensor_name, new_name)
-            used_names.append(new_name)
-
-        # post_model
-        for tensor in names2:
-            new_name = util.random_string()
-            while new_name in used_names:
-                new_name = util.random_string()
-            post_model.rename_tensor(tensor_name, new_name)
-            used_names.append(new_name)
-
         # check if models can be merged
         output_model_a = dyn_outp[0].name
         input_model_b = dyn_inp[0].name
@@ -124,6 +106,9 @@ class MergeONNXModels(Transformation):
             output_a_shape == input_b_shape
         ), "Models can't be merged! Shapes don't match."
 
+        pre_model.save("pre.onnx")
+        post_model.save("post.onnx")
+
         # connect output of one model to input of the other
         for n in pre_model.graph.node:
             if output_model_a == n.output[0]:
@@ -132,83 +117,43 @@ class MergeONNXModels(Transformation):
         # extract information for new model
 
         # nodes
-        node_list_a = pre_model.graph.node
-        node_list_b = post_model.graph.node
-
-        node_list = node_list_a
-        for node in node_list_b:
-            node_list.append(node)
+        node_pre = [node for node in pre_model.graph.node]
+        node_post = [node for node in post_model.graph.node]
+        node_new = node_pre + node_post
 
         # in and output
         inp = pre_model.graph.input[0]
         outp = post_model.graph.output[0]
 
+        vi_pre = [x for x in pre_model.graph.value_info]
+        out_pre = [x for x in pre_model.graph.output]
+        qa_pre = [x for x in pre_model.graph.quantization_annotation]
+        init_pre = [x for x in pre_model.graph.initializer]
+
+        vi_post = [x for x in post_model.graph.value_info]
+        qa_post = [x for x in post_model.graph.quantization_annotation]
+        init_post = [x for x in post_model.graph.initializer]
+
+        vi_new = vi_pre + vi_post + out_pre
+        qa_new = qa_pre + qa_post
+        init_new = init_pre + init_post
+
         # create new graph and model
         new_graph = helper.make_graph(
-            nodes=node_list,
+            nodes=node_new,
             name="fuse-graph",
             inputs=[inp],
             outputs=[outp],
-            value_info=[],
+            value_info=vi_new,
         )
 
         new_model = helper.make_model(new_graph, producer_name="fuse_model")
         new_model = ModelWrapper(new_model)
 
-        # add value info from both models to new model
-        # pre model
-        vi_pre = [x for x in pre_model.graph.input]
-        vi_pre += [x for x in pre_model.graph.output]
-        vi_pre += [x for x in pre_model.graph.value_info]
-        for vi in vi_pre:
-            # preserve intializers, quantization/sparsity annotation, etc.
-            # initializer
-            init_val = pre_model.get_initializer(vi.name)
-            if init_val is not None:
-                new_model.set_initializer(vi.name, init_val)
-            # FINN datatype
-            dtype = pre_model.get_tensor_datatype(vi.name)
-            new_model.set_tensor_datatype(vi.name, dtype)
-            # data layout
-            data_layout = pre_model.get_tensor_layout(vi.name)
-            if data_layout is not None:
-                new_model.set_tensor_layout(vi.name, data_layout)
-            # sparsity
-            sparsity = pre_model.get_tensor_sparsity(vi.name)
-            if sparsity is not None:
-                new_model.set_tensor_sparsity(vi.name, sparsity)
-            # graph input should not be part of graph.value_info, so don't insert
-            # if current vi == inp, but the quantization annotation is preserved
-            if vi == inp:
-                continue
-            new_model.graph.value_info.append(vi)
-
-        # post model
-        vi_model = [x for x in post_model.graph.input]
-        vi_model += [x for x in post_model.graph.output]
-        vi_model += [x for x in post_model.graph.value_info]
-        for vi in vi_model:
-            # preserve intializers, quantization/sparsity annotation, etc.
-            # initializer
-            init_val = post_model.get_initializer(vi.name)
-            if init_val is not None:
-                new_model.set_initializer(vi.name, init_val)
-            # FINN datatype
-            dtype = post_model.get_tensor_datatype(vi.name)
-            new_model.set_tensor_datatype(vi.name, dtype)
-            # data layout
-            data_layout = post_model.get_tensor_layout(vi.name)
-            if data_layout is not None:
-                new_model.set_tensor_layout(vi.name, data_layout)
-            # sparsity
-            sparsity = post_model.get_tensor_sparsity(vi.name)
-            if sparsity is not None:
-                new_model.set_tensor_sparsity(vi.name, sparsity)
-            # graph output should not be part of graph.value_info, so don't insert
-            # if current vi == outp, but the quantization annotation is preserved
-            if vi == outp:
-                continue
-            new_model.graph.value_info.append(vi)
+        for i in init_new:
+            new_model.graph.initializer.append(i)
+        for qa in qa_new:
+            new_model.graph.quantization_annotation.append(qa)
 
         # tidy-up new model
         model = new_model
-- 
GitLab