diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
index 18af62dafc693df9414cc6a8a1ec338607fc47e9..5dc6127ed189311c72a119932394aca4745e3608 100644
--- a/src/finn/transformation/merge_onnx_models.py
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -34,6 +34,7 @@ from finn.core.modelwrapper import ModelWrapper
 import finn.util.basic as util
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.transformation.general import (
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
@@ -59,16 +60,32 @@ class MergeONNXModels(Transformation):
         pre_model = self.pre_model
         post_model = copy.deepcopy(model)
 
-        if len(pre_model.graph.output) != 1:
+        # check for dynamic outputs of pre model
+        dyn_outp = []
+        for outp in pre_model.graph.output:
+            init_val = pre_model.get_initializer(outp.name)
+            if init_val is None:
+                dyn_outp.append(outp)
+
+        if len(dyn_outp) != 1:
             warnings.warn(
-                "The pre model has more than one output! The transformation tries "
-                "to connect output[0] to the input of the post model."
+                "The pre model has more than one dynamic output! The transformation "
+                "tries to connect the first dynamic output to the first dynamic input "
+                "of the post model."
             )
 
-        if len(post_model.graph.input) != 1:
+        # check for dynamic inputs of post model
+        dyn_inp = []
+        for inp in post_model.graph.input:
+            init_val = post_model.get_initializer(inp.name)
+            if init_val is None:
+                dyn_inp.append(inp)
+
+        if len(dyn_inp) != 1:
             warnings.warn(
-                "The post model has more than one input! The transformation tries "
-                "to connect input[0] to the output of the pre model."
+                "The post model has more than one dynamic input! The transformation "
+                "tries to connect the first dynamic input to the first dynamic output "
+                "of the pre model."
             )
 
         # erase all node names to avoid conflict
@@ -99,8 +116,8 @@ class MergeONNXModels(Transformation):
             used_names.append(new_name)
 
         # check if models can be merged
-        output_model_a = pre_model.graph.output[0].name
-        input_model_b = post_model.graph.input[0].name
+        output_model_a = dyn_outp[0].name
+        input_model_b = dyn_inp[0].name
         output_a_shape = pre_model.get_tensor_shape(output_model_a)
         input_b_shape = post_model.get_tensor_shape(input_model_b)
         assert (
@@ -144,11 +161,6 @@ class MergeONNXModels(Transformation):
         vi_pre += [x for x in pre_model.graph.output]
         vi_pre += [x for x in pre_model.graph.value_info]
         for vi in vi_pre:
-            # graph input should not be part of graph.value_info, so skip
-            # if current vi == inp
-            if vi == inp:
-                continue
-            new_model.graph.value_info.append(vi)
             # preserve intializers, quantization/sparsity annotation, etc.
             # initializer
             init_val = pre_model.get_initializer(vi.name)
@@ -165,17 +177,17 @@ class MergeONNXModels(Transformation):
             sparsity = pre_model.get_tensor_sparsity(vi.name)
             if sparsity is not None:
                 new_model.set_tensor_sparsity(vi.name, sparsity)
+            # graph input should not be part of graph.value_info, so don't insert
+            # it if current vi == inp; the annotations set above are still preserved
+            if vi == inp:
+                continue
+            new_model.graph.value_info.append(vi)
 
         # post model
         vi_model = [x for x in post_model.graph.input]
         vi_model += [x for x in post_model.graph.output]
         vi_model += [x for x in post_model.graph.value_info]
         for vi in vi_model:
-            # graph output should not be part of graph.value_info, so skip
-            # if current vi == outp
-            if vi == outp:
-                continue
-            new_model.graph.value_info.append(vi)
             # preserve intializers, quantization/sparsity annotation, etc.
             # initializer
             init_val = post_model.get_initializer(vi.name)
@@ -192,11 +204,17 @@ class MergeONNXModels(Transformation):
             sparsity = post_model.get_tensor_sparsity(vi.name)
             if sparsity is not None:
                 new_model.set_tensor_sparsity(vi.name, sparsity)
+            # graph output should not be part of graph.value_info, so don't insert
+            # it if current vi == outp; the annotations set above are still preserved
+            if vi == outp:
+                continue
+            new_model.graph.value_info.append(vi)
 
         # tidy-up new model
         model = new_model
         model = model.transform(InferShapes())
         model = model.transform(InferDataTypes())
+        model = model.transform(InferDataLayouts())
         model = model.transform(GiveUniqueNodeNames())
         model = model.transform(GiveUniqueParameterTensors())
         model = model.transform(GiveReadableTensorNames())