diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py index 18af62dafc693df9414cc6a8a1ec338607fc47e9..5dc6127ed189311c72a119932394aca4745e3608 100644 --- a/src/finn/transformation/merge_onnx_models.py +++ b/src/finn/transformation/merge_onnx_models.py @@ -34,6 +34,7 @@ from finn.core.modelwrapper import ModelWrapper import finn.util.basic as util from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.infer_data_layouts import InferDataLayouts from finn.transformation.general import ( GiveReadableTensorNames, GiveUniqueNodeNames, @@ -59,16 +60,32 @@ class MergeONNXModels(Transformation): pre_model = self.pre_model post_model = copy.deepcopy(model) - if len(pre_model.graph.output) != 1: + # check for dynamic outputs of pre model + dyn_outp = [] + for outp in pre_model.graph.output: + init_val = pre_model.get_initializer(outp.name) + if init_val is None: + dyn_outp.append(outp) + + if len(dyn_outp) != 1: warnings.warn( - "The pre model has more than one output! The transformation tries " - "to connect output[0] to the input of the post model." + "The pre model has more than one dynamic output! The transformation " + "tries to connect the first dynamic output to the first dynamic input " + "of the post model." ) - if len(post_model.graph.input) != 1: + # check for dynamic inputs of post model + dyn_inp = [] + for inp in post_model.graph.input: + init_val = post_model.get_initializer(inp.name) + if init_val is None: + dyn_inp.append(inp) + + if len(dyn_inp) != 1: warnings.warn( - "The post model has more than one input! The transformation tries " - "to connect input[0] to the output of the pre model." + "The post model has more than one dynamic input! The transformation " + "tries to connect the first dynamic input to the first dynamic output " + "of the pre model." ) # erase all node names to avoid conflict @@ -99,8 +116,8 @@ class MergeONNXModels(Transformation): used_names.append(new_name) # check if models can be merged - output_model_a = pre_model.graph.output[0].name - input_model_b = post_model.graph.input[0].name + output_model_a = dyn_outp[0].name + input_model_b = dyn_inp[0].name output_a_shape = pre_model.get_tensor_shape(output_model_a) input_b_shape = post_model.get_tensor_shape(input_model_b) assert ( @@ -144,11 +161,6 @@ class MergeONNXModels(Transformation): vi_pre += [x for x in pre_model.graph.output] vi_pre += [x for x in pre_model.graph.value_info] for vi in vi_pre: - # graph input should not be part of graph.value_info, so skip - # if current vi == inp - if vi == inp: - continue - new_model.graph.value_info.append(vi) # preserve intializers, quantization/sparsity annotation, etc. # initializer init_val = pre_model.get_initializer(vi.name) @@ -165,17 +177,17 @@ class MergeONNXModels(Transformation): sparsity = pre_model.get_tensor_sparsity(vi.name) if sparsity is not None: new_model.set_tensor_sparsity(vi.name, sparsity) + # graph input should not be part of graph.value_info, so don't insert + # if current vi == inp, but the quantization annotation is preserved + if vi == inp: + continue + new_model.graph.value_info.append(vi) # post model vi_model = [x for x in post_model.graph.input] vi_model += [x for x in post_model.graph.output] vi_model += [x for x in post_model.graph.value_info] for vi in vi_model: - # graph output should not be part of graph.value_info, so skip - # if current vi == outp - if vi == outp: - continue - new_model.graph.value_info.append(vi) # preserve intializers, quantization/sparsity annotation, etc. # initializer init_val = post_model.get_initializer(vi.name) @@ -192,11 +204,17 @@ class MergeONNXModels(Transformation): sparsity = post_model.get_tensor_sparsity(vi.name) if sparsity is not None: new_model.set_tensor_sparsity(vi.name, sparsity) + # graph output should not be part of graph.value_info, so don't insert + # if current vi == outp, but the quantization annotation is preserved + if vi == outp: + continue + new_model.graph.value_info.append(vi) # tidy-up new model model = new_model model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) + model = model.transform(InferDataLayouts()) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveUniqueParameterTensors()) model = model.transform(GiveReadableTensorNames())