Skip to content
Snippets Groups Projects
Commit cf8719a6 authored by auphelia's avatar auphelia
Browse files

[Transform] Add first draft of transformation to merge two onnx models together

parent ddd4fdbf
No related branches found
No related tags found
No related merge requests found
from onnx import helper
from finn.transformation import Transformation
from finn.core.modelwrapper import ModelWrapper
import finn.util.basic as util
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.general import (
GiveReadableTensorNames,
GiveUniqueNodeNames,
GiveUniqueParameterTensors,
)
def _make_model_values_unique(model1, model2):
# ensure that tensor and node names are different in each model
# tensors
names1 = model1.get_all_tensor_names()
names2 = model2.get_all_tensor_names()
duplicates = list(set(names1).intersection(names2))
# if there are duplicates in the tensor names rename these tensors
if duplicates:
used_names = names1 + names2
for name in duplicates:
# model1
new_name = util.random_string()
while new_name in used_names:
new_name = util.random_string()
model1.rename_tensor(name, new_name)
used_names.append(new_name)
# model2
new_name = util.random_string()
while new_name in used_names:
new_name = util.random_string()
model2.rename_tensor(name, new_name)
used_names.append(new_name)
# nodes
names1 = [x.name for x in model1.graph.node]
names1 = list(filter(None, names1)) # filter out empty node names
names2 = [x.name for x in model2.graph.node]
names2 = list(filter(None, names2))
duplicates = list(set(names1).intersection(names2))
# if there are duplicates erase all node names
if duplicates:
for n in model1.graph.node:
n.name = ""
for n in model2.graph.node:
n.name = ""
return (model1, model2)
class MergeONNXModels(Transformation):
def __init__(self, pre_proc_model):
super().__init__()
self.pre_proc_model = pre_proc_model
def apply(self, model):
graph_modified = False
pre_proc_model = self.pre_proc_model
(pre_proc_model, model) = _make_model_values_unique(pre_proc_model, model)
node_list_a = pre_proc_model.graph.node
node_list_b = model.graph.node
node_list = node_list_a
node_list[-1].output[0] = node_list_b[0].input[0]
for node in node_list_b:
node_list.append(node)
inp = pre_proc_model.graph.input[0]
outp = model.graph.output[0]
new_graph = helper.make_graph(
nodes=node_list,
name="fuse-graph",
inputs=[inp],
outputs=[outp],
value_info=[],
)
new_model = helper.make_model(new_graph, producer_name="fuse_model")
new_model = ModelWrapper(new_model)
vi_preproc = [x for x in pre_proc_model.graph.input]
vi_preproc += [x for x in pre_proc_model.graph.output]
vi_preproc += [x for x in pre_proc_model.graph.value_info]
for vi in vi_preproc:
if vi == inp:
continue
new_model.graph.value_info.append(vi)
init_val = pre_proc_model.get_initializer(vi.name)
if init_val is not None:
new_model.set_initializer(vi.name, init_val)
vi_model = [x for x in model.graph.input]
vi_model += [x for x in model.graph.output]
vi_model += [x for x in model.graph.value_info]
for vi in vi_model:
if vi == outp:
continue
new_model.graph.value_info.append(vi)
init_val = model.get_initializer(vi.name)
if init_val is not None:
new_model.set_initializer(vi.name, init_val)
model = new_model
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveUniqueParameterTensors())
model = model.transform(GiveReadableTensorNames())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment