From 6358106280872730fa7468d303b148aa9756f695 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 1 Jul 2020 10:50:24 +0100
Subject: [PATCH] [Transform] Add comments to MergeONNXModels

---
 src/finn/transformation/merge_onnx_models.py | 22 ++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
index 4cc056ead..0f46d0ab0 100644
--- a/src/finn/transformation/merge_onnx_models.py
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -95,6 +95,7 @@ class MergeONNXModels(Transformation):
         graph_modified = False
         pre_model = self.pre_model
 
+        # make model values unique to avoid any conflict
         (pre_model, model) = _make_model_values_unique(pre_model, model)
 
         # check if models can be merged
@@ -105,18 +106,27 @@ class MergeONNXModels(Transformation):
         assert (
             output_a_shape == input_b_shape
         ), "Models can't be merged! Shapes don't match."
+
+        # connect output of one model to input of the other
         for n in pre_model.graph.node:
             if output_model_a == n.output[0]:
                 n.output[0] = input_model_b
 
+        # extract information for new model
+
+        # nodes
         node_list_a = pre_model.graph.node
         node_list_b = model.graph.node
 
         node_list = node_list_a
         for node in node_list_b:
             node_list.append(node)
+
+        # in and output
         inp = pre_model.graph.input[0]
         outp = model.graph.output[0]
+
+        # create new graph and model
         new_graph = helper.make_graph(
             nodes=node_list,
             name="fuse-graph",
@@ -127,10 +137,12 @@ class MergeONNXModels(Transformation):
 
         new_model = helper.make_model(new_graph, producer_name="fuse_model")
         new_model = ModelWrapper(new_model)
-        vi_preproc = [x for x in pre_model.graph.input]
-        vi_preproc += [x for x in pre_model.graph.output]
-        vi_preproc += [x for x in pre_model.graph.value_info]
-        for vi in vi_preproc:
+
+        # add value info and initializers from both models to new model
+        vi_pre = [x for x in pre_model.graph.input]
+        vi_pre += [x for x in pre_model.graph.output]
+        vi_pre += [x for x in pre_model.graph.value_info]
+        for vi in vi_pre:
             if vi == inp:
                 continue
             new_model.graph.value_info.append(vi)
@@ -148,10 +160,12 @@ class MergeONNXModels(Transformation):
             if init_val is not None:
                 new_model.set_initializer(vi.name, init_val)
 
+        # tidy-up new model
         model = new_model
         model = model.transform(InferShapes())
         model = model.transform(InferDataTypes())
         model = model.transform(GiveUniqueNodeNames())
         model = model.transform(GiveUniqueParameterTensors())
         model = model.transform(GiveReadableTensorNames())
+
         return (model, graph_modified)
-- 
GitLab