Skip to content
Snippets Groups Projects
Commit 63581062 authored by auphelia's avatar auphelia
Browse files

[Transform] Add comments to MergeONNXModels

parent 5aca8761
No related branches found
No related tags found
No related merge requests found
......@@ -95,6 +95,7 @@ class MergeONNXModels(Transformation):
graph_modified = False
pre_model = self.pre_model
# make model values unique to avoid any conflict
(pre_model, model) = _make_model_values_unique(pre_model, model)
# check if models can be merged
......@@ -105,18 +106,27 @@ class MergeONNXModels(Transformation):
assert (
output_a_shape == input_b_shape
), "Models can't be merged! Shapes don't match."
# connect output of one model to input of the other
for n in pre_model.graph.node:
if output_model_a == n.output[0]:
n.output[0] = input_model_b
# extract information for new model
# nodes
node_list_a = pre_model.graph.node
node_list_b = model.graph.node
node_list = node_list_a
for node in node_list_b:
node_list.append(node)
# in and output
inp = pre_model.graph.input[0]
outp = model.graph.output[0]
# create new graph and model
new_graph = helper.make_graph(
nodes=node_list,
name="fuse-graph",
......@@ -127,10 +137,12 @@ class MergeONNXModels(Transformation):
new_model = helper.make_model(new_graph, producer_name="fuse_model")
new_model = ModelWrapper(new_model)
vi_preproc = [x for x in pre_model.graph.input]
vi_preproc += [x for x in pre_model.graph.output]
vi_preproc += [x for x in pre_model.graph.value_info]
for vi in vi_preproc:
# add value info and initializers from both models to new model
vi_pre = [x for x in pre_model.graph.input]
vi_pre += [x for x in pre_model.graph.output]
vi_pre += [x for x in pre_model.graph.value_info]
for vi in vi_pre:
if vi == inp:
continue
new_model.graph.value_info.append(vi)
......@@ -148,10 +160,12 @@ class MergeONNXModels(Transformation):
if init_val is not None:
new_model.set_initializer(vi.name, init_val)
# tidy-up new model
model = new_model
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveUniqueParameterTensors())
model = model.transform(GiveReadableTensorNames())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment