Skip to content
Snippets Groups Projects
Commit 7860d0c0 authored by auphelia's avatar auphelia
Browse files

[Transform] Restructure MergeONNXModels to correctly preserve quantization...

[Transform] Restructure MergeONNXModels to correctly preserve quantization information of non value_info tensors
parent b9c1cd76
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ from finn.core.modelwrapper import ModelWrapper
import finn.util.basic as util
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import (
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -59,16 +60,32 @@ class MergeONNXModels(Transformation):
pre_model = self.pre_model
post_model = copy.deepcopy(model)
if len(pre_model.graph.output) != 1:
# check for dynamic outputs of pre model
dyn_outp = []
for outp in pre_model.graph.output:
init_val = pre_model.get_initializer(outp.name)
if init_val is None:
dyn_outp.append(outp)
if len(dyn_outp) != 1:
warnings.warn(
"The pre model has more than one output! The transformation tries "
"to connect output[0] to the input of the post model."
"The pre model has more than one dynamic output! The transformation "
"tries to connect the first dynamic output to the first dynamic input "
"of the post model."
)
if len(post_model.graph.input) != 1:
# check for dynamic inputs of post model
dyn_inp = []
for inp in post_model.graph.input:
init_val = post_model.get_initializer(inp.name)
if init_val is None:
dyn_inp.append(inp)
if len(dyn_inp) != 1:
warnings.warn(
"The post model has more than one input! The transformation tries "
"to connect input[0] to the output of the pre model."
"The post model has more than one dynamic input! The transformation "
"tries to connect the first dynamic input to the first dynamic output "
"of the pre model."
)
# erase all node names to avoid conflict
......@@ -99,8 +116,8 @@ class MergeONNXModels(Transformation):
used_names.append(new_name)
# check if models can be merged
output_model_a = pre_model.graph.output[0].name
input_model_b = post_model.graph.input[0].name
output_model_a = dyn_outp[0].name
input_model_b = dyn_inp[0].name
output_a_shape = pre_model.get_tensor_shape(output_model_a)
input_b_shape = post_model.get_tensor_shape(input_model_b)
assert (
......@@ -144,11 +161,6 @@ class MergeONNXModels(Transformation):
vi_pre += [x for x in pre_model.graph.output]
vi_pre += [x for x in pre_model.graph.value_info]
for vi in vi_pre:
# graph input should not be part of graph.value_info, so skip
# if current vi == inp
if vi == inp:
continue
new_model.graph.value_info.append(vi)
# preserve initializers, quantization/sparsity annotation, etc.
# initializer
init_val = pre_model.get_initializer(vi.name)
......@@ -165,17 +177,17 @@ class MergeONNXModels(Transformation):
sparsity = pre_model.get_tensor_sparsity(vi.name)
if sparsity is not None:
new_model.set_tensor_sparsity(vi.name, sparsity)
# graph input should not be part of graph.value_info, so don't insert
# if current vi == inp, but the quantization annotation is preserved
if vi == inp:
continue
new_model.graph.value_info.append(vi)
# post model
vi_model = [x for x in post_model.graph.input]
vi_model += [x for x in post_model.graph.output]
vi_model += [x for x in post_model.graph.value_info]
for vi in vi_model:
# graph output should not be part of graph.value_info, so skip
# if current vi == outp
if vi == outp:
continue
new_model.graph.value_info.append(vi)
# preserve initializers, quantization/sparsity annotation, etc.
# initializer
init_val = post_model.get_initializer(vi.name)
......@@ -192,11 +204,17 @@ class MergeONNXModels(Transformation):
sparsity = post_model.get_tensor_sparsity(vi.name)
if sparsity is not None:
new_model.set_tensor_sparsity(vi.name, sparsity)
# graph output should not be part of graph.value_info, so don't insert
# if current vi == outp, but the quantization annotation is preserved
if vi == outp:
continue
new_model.graph.value_info.append(vi)
# tidy-up new model
model = new_model
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveUniqueParameterTensors())
model = model.transform(GiveReadableTensorNames())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment