diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6a3fe38f35413afc1b5990ea5aaf84b06d02f5
--- /dev/null
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -0,0 +1,111 @@
+from onnx import helper
+
+from finn.transformation import Transformation
+from finn.core.modelwrapper import ModelWrapper
+import finn.util.basic as util
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.general import (
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+    GiveUniqueParameterTensors,
+)
+
+
+def _make_model_values_unique(model1, model2):
+    # ensure that tensor and node names are different in each model
+    # tensors
+    names1 = model1.get_all_tensor_names()
+    names2 = model2.get_all_tensor_names()
+    duplicates = list(set(names1).intersection(names2))
+    # if there are duplicates in the tensor names rename these tensors
+    if duplicates:
+        used_names = names1 + names2
+        for name in duplicates:
+            # model1
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            model1.rename_tensor(name, new_name)
+            used_names.append(new_name)
+
+            # model2
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            model2.rename_tensor(name, new_name)
+            used_names.append(new_name)
+
+    # nodes
+    names1 = [x.name for x in model1.graph.node]
+    names1 = list(filter(None, names1))  # filter out empty node names
+    names2 = [x.name for x in model2.graph.node]
+    names2 = list(filter(None, names2))
+    duplicates = list(set(names1).intersection(names2))
+    # if there are duplicates erase all node names
+    if duplicates:
+        for n in model1.graph.node:
+            n.name = ""
+        for n in model2.graph.node:
+            n.name = ""
+
+    return (model1, model2)
+
+
+class MergeONNXModels(Transformation):
+    def __init__(self, pre_proc_model):
+        super().__init__()
+        self.pre_proc_model = pre_proc_model
+
+    def apply(self, model):
+        graph_modified = False
+        pre_proc_model = self.pre_proc_model
+
+        (pre_proc_model, model) = _make_model_values_unique(pre_proc_model, model)
+
+        node_list_a = pre_proc_model.graph.node
+        node_list_b = model.graph.node
+        node_list = node_list_a
+        node_list[-1].output[0] = node_list_b[0].input[0]
+        for node in node_list_b:
+            node_list.append(node)
+        inp = pre_proc_model.graph.input[0]
+        outp = model.graph.output[0]
+        new_graph = helper.make_graph(
+            nodes=node_list,
+            name="fuse-graph",
+            inputs=[inp],
+            outputs=[outp],
+            value_info=[],
+        )
+
+        new_model = helper.make_model(new_graph, producer_name="fuse_model")
+        new_model = ModelWrapper(new_model)
+        vi_preproc = [x for x in pre_proc_model.graph.input]
+        vi_preproc += [x for x in pre_proc_model.graph.output]
+        vi_preproc += [x for x in pre_proc_model.graph.value_info]
+        for vi in vi_preproc:
+            if vi == inp:
+                continue
+            new_model.graph.value_info.append(vi)
+            init_val = pre_proc_model.get_initializer(vi.name)
+            if init_val is not None:
+                new_model.set_initializer(vi.name, init_val)
+        vi_model = [x for x in model.graph.input]
+        vi_model += [x for x in model.graph.output]
+        vi_model += [x for x in model.graph.value_info]
+        for vi in vi_model:
+            if vi == outp:
+                continue
+            new_model.graph.value_info.append(vi)
+            init_val = model.get_initializer(vi.name)
+            if init_val is not None:
+                new_model.set_initializer(vi.name, init_val)
+
+        model = new_model
+        model = model.transform(InferShapes())
+        model = model.transform(InferDataTypes())
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(GiveUniqueParameterTensors())
+        model = model.transform(GiveReadableTensorNames())
+        return (model, graph_modified)