diff --git a/docs/finn-sheduling-and-folding.pptx b/docs/finn-sheduling-and-folding.pptx
new file mode 100644
index 0000000000000000000000000000000000000000..30bbe4d55b1cda9df25a791227983dc7cb750e58
Binary files /dev/null and b/docs/finn-sheduling-and-folding.pptx differ
diff --git a/src/finn/custom_op/fpgadataflow/labelselect_batch.py b/src/finn/custom_op/fpgadataflow/labelselect_batch.py
index 7591f09d8d0cd1847672fe5aa09616ff1571033d..f61fbf12da889700297006ef2566088d4150c0e4 100644
--- a/src/finn/custom_op/fpgadataflow/labelselect_batch.py
+++ b/src/finn/custom_op/fpgadataflow/labelselect_batch.py
@@ -41,6 +41,13 @@ class LabelSelect_Batch(HLSCustomOp):
 
     def __init__(self, onnx_node):
         super().__init__(onnx_node)
+        odt_name = self.get_nodeattr("outputDataType")
+        if odt_name == "":
+            # If not provided, infer the smallest DataType that can hold all label indices
+            labels = self.get_nodeattr("Labels")
+            odt = DataType.get_smallest_possible(labels - 1)
+            odt_name = odt.name
+            self.set_nodeattr("outputDataType", odt_name)
 
     def get_nodeattr_types(self):
         my_attrs = {
@@ -49,6 +56,7 @@ class LabelSelect_Batch(HLSCustomOp):
             "K": ("i", True, 0),
             # FINN DataTypes for input
             "inputDataType": ("s", True, ""),
+            "outputDataType": ("s", False, ""),
             # number of input vectors, examples:
             # [1] is a single vector (like a FC layer with batch=1)
             # [4] is four vectors (like a FC layer with batch=4)
@@ -69,7 +77,6 @@ class LabelSelect_Batch(HLSCustomOp):
         pe = self.get_nodeattr("PE")
         vecs = list(self.get_nodeattr("numInputVectors"))
         assert nlabels % pe == 0, "PE must divide Labels"
-        assert pe == 1, "LabelSelect currently fails with folding"
         folds = int(nlabels / pe)
         folded_ishape = tuple(vecs + [folds, pe])
         return folded_ishape
@@ -90,7 +97,7 @@ class LabelSelect_Batch(HLSCustomOp):
         exp_ishape = self.get_normal_input_shape()
         oshape = self.get_normal_output_shape()
         ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
-        assert ishape == exp_ishape, "Unexpect input shape."
+        assert ishape == exp_ishape, "Unexpected input shape."
         # implement tensor with correct shape
         values = np.random.randn(*oshape).astype(np.int64)
         return helper.make_node(
@@ -106,9 +113,8 @@ class LabelSelect_Batch(HLSCustomOp):
         )
 
     def infer_node_datatype(self, model):
-        # currently set to uint32 to be compatible with hlslib
-        # enhancement: consider finding smallest power-of-two int for reduced output bandwidth
-        model.set_tensor_datatype(self.onnx_node.output[0], DataType.UINT32)
+        odt = self.get_output_datatype()
+        model.set_tensor_datatype(self.onnx_node.output[0], odt)
 
     def verify_node(self):
         info_messages = []
@@ -134,6 +140,7 @@ class LabelSelect_Batch(HLSCustomOp):
             self.get_nodeattr("PE")
             self.get_nodeattr("K")
             self.get_nodeattr("inputDataType")
+            self.get_nodeattr("outputDataType")
             info_messages.append("All necessary attributes exist")
         except Exception:
             info_messages.append(
@@ -150,12 +157,12 @@ class LabelSelect_Batch(HLSCustomOp):
     def get_input_datatype(self):
         """Returns FINN DataType of input."""
         ret = DataType[self.get_nodeattr("inputDataType")]
-        assert ret.signed() is False, "LabelSelect is currently broken for signed inputs"
         return ret
 
     def get_output_datatype(self):
         """Returns FINN DataType of output."""
-        return DataType.UINT32
+        ret = DataType[self.get_nodeattr("outputDataType")]
+        return ret
 
     def get_instream_width(self):
         """Returns input stream width."""
@@ -260,8 +267,13 @@ class LabelSelect_Batch(HLSCustomOp):
         npy_type = "float"
         npy_in = "%s/input_0.npy" % code_gen_dir
         self.code_gen_dict["$READNPYDATA$"] = []
+
+        # Call npy2apintstream with reverse_inner = false to get LE packing,
+        # as required by the HLS function LabelSelect_Batch.
+        # Note that StreamingDataWidthConverter_Batch also performs LE packing.
+
         self.code_gen_dict["$READNPYDATA$"].append(
-            'npy2apintstream<%s, %s, %d, %s>("%s", in0);'
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0, false);'
             % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
         )
 
@@ -277,12 +289,13 @@ class LabelSelect_Batch(HLSCustomOp):
     def docompute(self):
         node = self.onnx_node
         self.code_gen_dict["$DOCOMPUTE$"] = [
-            """{}<{}, {}, {}, {}, ap_uint<32>> (in0, out, 1);""".format(
+            """{}<{}, {}, {}, {}, {} > (in0, out, 1);""".format(
                 node.op_type,
                 self.get_nodeattr("Labels"),
                 self.get_nodeattr("PE"),
                 self.get_nodeattr("K"),
                 self.get_input_datatype().get_hls_datatype_str(),
+                self.get_output_datatype().get_hls_datatype_str(),
             )
         ]
 
@@ -316,10 +329,11 @@ class LabelSelect_Batch(HLSCustomOp):
     def blackboxfunction(self):
         self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
             """void {}(hls::stream<ap_uint<{}*{}>> &in0,
-                hls::stream<ap_uint<32>> &out)""".format(
+                hls::stream<ap_uint<{}> > &out)""".format(
                 self.onnx_node.name,
                 self.get_nodeattr("PE"),
                 self.get_input_datatype().bitwidth(),
+                self.get_output_datatype().bitwidth(),
             )
         ]
 
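For reference, DataType.get_smallest_possible(labels - 1) picks the narrowest
FINN DataType that can represent the largest label index. A minimal standalone
sketch of the underlying bit-width arithmetic (not FINN itself; FINN then maps
the width onto one of its named UINT types):

```python
import math

def smallest_uint_bits(max_value: int) -> int:
    """Bits needed to hold max_value as an unsigned integer."""
    if max_value <= 0:
        return 1
    return math.ceil(math.log2(max_value + 1))

assert smallest_uint_bits(10 - 1) == 4     # 10 labels -> indices 0..9 -> UINT4
assert smallest_uint_bits(1000 - 1) == 10  # 1000 labels -> UINT10
```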
diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 488391740fc25f1f7caa657adc9ed55bdc2f9722..4303eb17f39a9949f5729e895e449bbb6a633033 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -31,6 +31,55 @@ from finn.transformation import Transformation
 from toposort import toposort_flatten
 
 
+class RemoveUnusedTensors(Transformation):
+    """Remove any unused tensors in the graph by removing any initializers,
+    ValueInfo and tensor annotations associated with it. Unused tensors do not
+    appear as any input/output for any graph nodes.
+    """
+
+    def apply(self, model):
+        graph_modified = False
+        onnx_graph = model.model.graph
+        # build a set of tensors that we actually use in the graph nodes
+        used_tensors = set()
+        for node in model.graph.node:
+            for i in node.input:
+                used_tensors.add(i)
+            for o in node.output:
+                used_tensors.add(o)
+        # remove initializers, value_info and annotations that are not in the
+        # used set of tensors, as determined by the graph node i/o
+        # note: removing entries while iterating over a protobuf repeated field
+        # may skip elements; this is safe here because the transformation is
+        # re-applied until it reports no further changes
+        for init in onnx_graph.initializer:
+            if init.name not in used_tensors:
+                onnx_graph.initializer.remove(init)
+                graph_modified = True
+        for vi in onnx_graph.value_info:
+            if vi.name not in used_tensors:
+                onnx_graph.value_info.remove(vi)
+                graph_modified = True
+        for qa in onnx_graph.quantization_annotation:
+            if qa.tensor_name not in used_tensors:
+                onnx_graph.quantization_annotation.remove(qa)
+                graph_modified = True
+
+        return (model, graph_modified)
+
+
+class RemoveStaticGraphInputs(Transformation):
+    "Remove any top-level graph inputs that have initializers."
+
+    def apply(self, model):
+        graph_modified = False
+        for i in model.graph.input:
+            if model.get_initializer(i.name) is not None:
+                # move ValueInfo to internal (value_info) container
+                model.graph.value_info.append(i)
+                model.graph.input.remove(i)
+                graph_modified = True
+
+        return (model, graph_modified)
+
+
 class GiveUniqueNodeNames(Transformation):
     """Give unique names to each node in the graph using enumeration."""
 
@@ -121,24 +170,23 @@ class GiveUniqueParameterTensors(Transformation):
 
 class SortGraph(Transformation):
     """ Returns the model with its node list sorted topologically.
-    Any ONNX graph to be executed must have a topologically sorted node list, as dictated
-    by the ONNX standard.
+    Any ONNX graph to be executed must have a topologically sorted node list,
+    as dictated by the ONNX standard.
     """
-    
+
     # Notes on SortGraph performance:
-    #         benchmark in  tests/transformation/test_sort_graph.py
-    # 
-    #         The algorithm doesn't move initializers so its performance should only depend on
-    #         the number of nodes
-    # 
-    #         Relative order of magnitudes for time per step:
-    #             - Gather graph structure:       base
-    #             - Sort nodes:                   0.1 of base
-    #             - Remove and insert in order :  0.001 of base
-    # 
-    #     Notes:
-    #         Remove nodes and insert them in order:
-    #           Probably this is faster than copying initializers and more robust in general
+    # benchmark in  tests/transformation/test_sort_graph.py
+    # The algorithm doesn't move initializers so its performance should only depend on
+    # the number of nodes
+    #
+    # Relative order of magnitudes for time per step:
+    # - Gather graph structure:       base
+    # - Sort nodes:                   0.1 of base
+    # - Remove and insert in order:  0.001 of base
+    #
+    # Notes:
+    # Remove nodes and insert them in order:
+    # Probably this is faster than copying initializers and more robust in general
 
     def apply(self, model):
         # Gather graph structure
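
A hedged usage sketch for the two new tidy-up passes (the file name is
hypothetical; assumes a FINN ModelWrapper loaded from an exported model):

```python
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.general import (
    RemoveStaticGraphInputs,
    RemoveUnusedTensors,
)

model = ModelWrapper("exported.onnx")  # hypothetical file name
# constant folding and streamlining often leave initializer-backed graph
# inputs and dangling tensors behind; these two passes clean both up
model = model.transform(RemoveStaticGraphInputs())
model = model.transform(RemoveUnusedTensors())
assert len(model.graph.input) == 1  # only the dynamic top-level input remains
```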
diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc6127ed189311c72a119932394aca4745e3608
--- /dev/null
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import copy
+import warnings
+from onnx import helper
+
+from finn.transformation import Transformation
+from finn.core.modelwrapper import ModelWrapper
+import finn.util.basic as util
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import (
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+    GiveUniqueParameterTensors,
+)
+
+
+class MergeONNXModels(Transformation):
+    """Merges two models. The model passed in the transformation will be inserted before
+    the model the transformation is applied on, the resulting model is returned.
+    This transformation will try to connect graph.output[0] of the pre model and
+    graph.input[0] of the post model.
+    If more than one input or output exists, a warning is raised."""
+
+    def __init__(self, pre_model):
+        super().__init__()
+        # use deep copy of model that should be inserted in the beginning of
+        # the other model to ensure that it stays unchanged
+        self.pre_model = copy.deepcopy(pre_model)
+
+    def apply(self, model):
+        graph_modified = False
+        pre_model = self.pre_model
+        post_model = copy.deepcopy(model)
+
+        # check for dynamic outputs of pre model
+        dyn_outp = []
+        for outp in pre_model.graph.output:
+            init_val = pre_model.get_initializer(outp.name)
+            if init_val is None:
+                dyn_outp.append(outp)
+
+        if len(dyn_outp) != 1:
+            warnings.warn(
+                "The pre model has more than one dynamic output! The transformation "
+                "tries to connect the first dynamic output to the first dynamic input "
+                "of the post model."
+            )
+
+        # check for dynamic inputs of post model
+        dyn_inp = []
+        for inp in post_model.graph.input:
+            init_val = post_model.get_initializer(inp.name)
+            if init_val is None:
+                dyn_inp.append(inp)
+
+        if len(dyn_inp) != 1:
+            warnings.warn(
+                "The post model has more than one dynamic input! The transformation "
+                "tries to connect the first dynamic input to the first dynamic output "
+                "of the pre model."
+            )
+
+        # erase all node names to avoid conflict
+        for n in pre_model.graph.node:
+            n.name = ""
+        for n in post_model.graph.node:
+            n.name = ""
+
+        # randomize all tensor names
+        names1 = pre_model.get_all_tensor_names()
+        names2 = post_model.get_all_tensor_names()
+        used_names = names1 + names2
+
+        # pre_model
+        for tensor_name in names1:
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            pre_model.rename_tensor(tensor_name, new_name)
+            used_names.append(new_name)
+
+        # post_model
+        for tensor_name in names2:
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            post_model.rename_tensor(tensor_name, new_name)
+            used_names.append(new_name)
+
+        # check if models can be merged
+        output_model_a = dyn_outp[0].name
+        input_model_b = dyn_inp[0].name
+        output_a_shape = pre_model.get_tensor_shape(output_model_a)
+        input_b_shape = post_model.get_tensor_shape(input_model_b)
+        assert (
+            output_a_shape == input_b_shape
+        ), "Models can't be merged! Shapes don't match."
+
+        # connect output of one model to input of the other
+        for n in pre_model.graph.node:
+            if output_model_a == n.output[0]:
+                n.output[0] = input_model_b
+
+        # extract information for new model
+
+        # nodes
+        node_list_a = pre_model.graph.node
+        node_list_b = post_model.graph.node
+
+        node_list = node_list_a
+        for node in node_list_b:
+            node_list.append(node)
+
+        # input and output
+        inp = pre_model.graph.input[0]
+        outp = post_model.graph.output[0]
+
+        # create new graph and model
+        new_graph = helper.make_graph(
+            nodes=node_list,
+            name="fuse-graph",
+            inputs=[inp],
+            outputs=[outp],
+            value_info=[],
+        )
+
+        new_model = helper.make_model(new_graph, producer_name="fuse_model")
+        new_model = ModelWrapper(new_model)
+
+        # add value info from both models to new model
+        # pre model
+        vi_pre = [x for x in pre_model.graph.input]
+        vi_pre += [x for x in pre_model.graph.output]
+        vi_pre += [x for x in pre_model.graph.value_info]
+        for vi in vi_pre:
+            # preserve initializers, quantization/sparsity annotation, etc.
+            # initializer
+            init_val = pre_model.get_initializer(vi.name)
+            if init_val is not None:
+                new_model.set_initializer(vi.name, init_val)
+            # FINN datatype
+            dtype = pre_model.get_tensor_datatype(vi.name)
+            new_model.set_tensor_datatype(vi.name, dtype)
+            # data layout
+            data_layout = pre_model.get_tensor_layout(vi.name)
+            if data_layout is not None:
+                new_model.set_tensor_layout(vi.name, data_layout)
+            # sparsity
+            sparsity = pre_model.get_tensor_sparsity(vi.name)
+            if sparsity is not None:
+                new_model.set_tensor_sparsity(vi.name, sparsity)
+            # the graph input must not appear in graph.value_info, so skip it
+            # here; its annotations were already preserved above
+            if vi == inp:
+                continue
+            new_model.graph.value_info.append(vi)
+
+        # post model
+        vi_model = [x for x in post_model.graph.input]
+        vi_model += [x for x in post_model.graph.output]
+        vi_model += [x for x in post_model.graph.value_info]
+        for vi in vi_model:
+            # preserve initializers, quantization/sparsity annotation, etc.
+            # initializer
+            init_val = post_model.get_initializer(vi.name)
+            if init_val is not None:
+                new_model.set_initializer(vi.name, init_val)
+            # FINN datatype
+            dtype = post_model.get_tensor_datatype(vi.name)
+            new_model.set_tensor_datatype(vi.name, dtype)
+            # data layout
+            data_layout = post_model.get_tensor_layout(vi.name)
+            if data_layout is not None:
+                new_model.set_tensor_layout(vi.name, data_layout)
+            # sparsity
+            sparsity = post_model.get_tensor_sparsity(vi.name)
+            if sparsity is not None:
+                new_model.set_tensor_sparsity(vi.name, sparsity)
+            # the graph output must not appear in graph.value_info, so skip it
+            # here; its annotations were already preserved above
+            if vi == outp:
+                continue
+            new_model.graph.value_info.append(vi)
+
+        # tidy-up new model
+        model = new_model
+        model = model.transform(InferShapes())
+        model = model.transform(InferDataTypes())
+        model = model.transform(InferDataLayouts())
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(GiveUniqueParameterTensors())
+        model = model.transform(GiveReadableTensorNames())
+
+        return (model, graph_modified)
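
A hedged usage sketch for MergeONNXModels (file names are hypothetical; the
shape of the pre model's dynamic output must match the post model's dynamic
input, as asserted by the transformation):

```python
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.merge_onnx_models import MergeONNXModels

pre_model = ModelWrapper("preproc.onnx")   # e.g. a preprocessing graph
post_model = ModelWrapper("network.onnx")  # the main network
# insert pre_model in front of post_model; the result is already tidied up
merged = post_model.transform(MergeONNXModels(pre_model))
```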
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index f089275c221f769daace3e9628a00bf87b4e5457..8398a277443530e84632d26fbfca6d90ea4b0b9e 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -317,11 +317,13 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Transpose":
+            if n.op_type == "Transpose" and not model.is_fork_node(n):
                 perms = list(get_by_name(n.attribute, "perm").ints)
                 if perms == [0, 3, 1, 2]:
                     mt_cand = model.find_consumer(n.output[0])
-                    if mt_cand.op_type == "MultiThreshold":
+                    if mt_cand.op_type == "MultiThreshold" and not model.is_fork_node(
+                        mt_cand
+                    ):
                         final_t_cand = model.find_consumer(mt_cand.output[0])
                         if final_t_cand.op_type == "Transpose":
                             perms = list(
@@ -358,6 +360,7 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
 
+
 class AbsorbTransposeIntoFlatten(Transformation):
     """Absorb transpose node into succeeding flatten node, if H=W=1 and the first
     dimension stays the same. Can also be applied if flatten is implemented implicitly
@@ -417,9 +420,10 @@ class AbsorbTransposeIntoFlatten(Transformation):
                     graph.node.insert(node_ind, node)
                     graph_modified = True
         if graph_modified:
-          model = model.transform(InferDataTypes())
+            model = model.transform(InferDataTypes())
         return (model, graph_modified)
-      
+
+
 class AbsorbScalarMulIntoTopK(Transformation):
     """Absorb a mul node into a suceeding topk node if the mul is scalar."""
 
@@ -453,3 +457,84 @@ class AbsorbScalarMulIntoTopK(Transformation):
             model = model.transform(InferShapes())
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class AbsorbConsecutiveTransposes(Transformation):
+    """Remove (Transpose -> Transpose) patterns when the input and output
+    of the pattern have the same layout."""
+
+    def Are_opposite_permutations(self, perms1, perms2):
+        if len(perms1) != len(perms2):
+            return False
+        assert 0 <= max(perms2) < len(perms2), "invalid permutation"
+        assert 0 <= max(perms1) < len(perms1), "invalid permutation"
+
+        for i, p in enumerate(perms2):
+            if perms1[p] != i:
+                return False
+
+        return True
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        for n in graph.node:
+            if n.op_type == "Transpose":
+                if model.is_fork_node(n):
+                    next_nodes = model.find_direct_successors(n)
+                    perms1 = list(get_by_name(n.attribute, "perm").ints)
+
+                    # check if all nodes after fork are opposite transposes
+                    all_opposite_transposes = True
+                    for next_node in next_nodes:
+                        if next_node is not None and next_node.op_type == "Transpose":
+                            perms2 = list(get_by_name(next_node.attribute, "perm").ints)
+                            if not self.Are_opposite_permutations(perms1, perms2):
+                                all_opposite_transposes = False
+                                break
+                        else:
+                            all_opposite_transposes = False
+                            break
+
+                    if not all_opposite_transposes:
+                        continue
+
+                    prod = model.find_producer(n.input[0])
+                    for next_node in next_nodes:
+                        # connect next_node's consumer input to n's producer output
+                        # TODO implement this to allow for forks as producers and
+                        # joins as consumers
+                        cons = model.find_consumer(next_node.output[0])
+                        cons.input[0] = prod.output[0]
+
+                        # remove consumer transpose
+                        graph.node.remove(next_node)
+
+                    # remove producer transpose
+                    graph.node.remove(n)
+                    graph_modified = True
+
+                else:
+                    next_node = model.find_consumer(n.output[0])
+                    if next_node is not None and next_node.op_type == "Transpose":
+                        perms1 = list(get_by_name(n.attribute, "perm").ints)
+                        perms2 = list(get_by_name(next_node.attribute, "perm").ints)
+                        if self.Are_opposite_permutations(perms1, perms2):
+
+                            # rewire by renaming the producer's output to the
+                            # tensor name the consumers already read from
+                            # TODO implement this to allow for forks as producers
+                            consumers = model.find_direct_successors(next_node)
+                            prod = model.find_producer(n.input[0])
+                            for cons in consumers:
+                                for cons_in in cons.input:
+                                    if cons_in == next_node.output[0]:
+                                        prod.output[0] = cons_in
+                                        break
+                            # remove both transposes
+                            graph.node.remove(n)
+                            graph.node.remove(next_node)
+
+                            graph_modified = True
+        if graph_modified:
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
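
For reference, a standalone sketch of the opposite-permutation test used by
AbsorbConsecutiveTransposes: two Transpose nodes cancel exactly when composing
their perm attributes yields the identity permutation.

```python
def are_opposite_permutations(perms1, perms2):
    """True if a Transpose with perms2 undoes a Transpose with perms1."""
    if len(perms1) != len(perms2):
        return False
    return all(perms1[p] == i for i, p in enumerate(perms2))

# NCHW -> NHWC followed by NHWC -> NCHW cancels out
assert are_opposite_permutations([0, 2, 3, 1], [0, 3, 1, 2])
assert not are_opposite_permutations([0, 2, 3, 1], [0, 2, 3, 1])
```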
diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py
index f91ca600d3f0ce3b1cda3c29216fe8e0e3f415e4..764671bee13710ef1d9fa21aab5ef600075b9b0d 100644
--- a/tests/brevitas/test_brevitas_cnv.py
+++ b/tests/brevitas/test_brevitas_cnv.py
@@ -38,7 +38,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.infer_shapes import InferShapes
-from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
 from finn.util.test import get_test_model_trained
 
@@ -57,6 +57,9 @@ def test_brevitas_cnv_export_exec(wbits, abits):
     model = model.transform(DoubleToSingleFloat())
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
+    model = model.transform(RemoveStaticGraphInputs())
+    assert len(model.graph.input) == 1
+    assert len(model.graph.output) == 1
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
     input_tensor = input_tensor / 255
diff --git a/tests/brevitas/test_brevitas_fc.py b/tests/brevitas/test_brevitas_fc.py
index db18d91e3590e896e111c9e38bdc4de43872a98c..9369b25385080875efcb286c02291fc579a15a34 100644
--- a/tests/brevitas/test_brevitas_fc.py
+++ b/tests/brevitas/test_brevitas_fc.py
@@ -39,6 +39,7 @@ import torch
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import RemoveStaticGraphInputs
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import make_build_dir
 from finn.util.test import get_test_model_trained
@@ -63,6 +64,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits):
     model = ModelWrapper(finn_onnx)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
+    model = model.transform(RemoveStaticGraphInputs())
+    assert len(model.graph.input) == 1
+    assert len(model.graph.output) == 1
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
diff --git a/tests/end2end/test_end2end_cnv_w1a1.py b/tests/end2end/test_end2end_cnv_w1a1.py
index e3f281904d7db1349d74d6eb70cad20a8f3d10af..a2cfcd3a864c12788c2ac73271b5782ddfa336c1 100644
--- a/tests/end2end/test_end2end_cnv_w1a1.py
+++ b/tests/end2end/test_end2end_cnv_w1a1.py
@@ -42,7 +42,12 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.streamline import Streamline
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
 from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
@@ -97,6 +102,7 @@ def test_end2end_cnv_w1a1_import_and_tidy():
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_cnv_w1a1_tidy.onnx")
 
 
@@ -108,6 +114,7 @@ def test_end2end_cnv_w1a1_streamline():
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(ConvertBipolarMatMulToXnorPopcount())
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_cnv_w1a1_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_cnv_w2a2.py b/tests/end2end/test_end2end_cnv_w2a2.py
index 31ccebd4c175ad2badef17499bf113d978b637f7..f45b0a3eccd2f52ea144405865a1df06315952d9 100644
--- a/tests/end2end/test_end2end_cnv_w2a2.py
+++ b/tests/end2end/test_end2end_cnv_w2a2.py
@@ -42,7 +42,12 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.streamline import Streamline
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
 import finn.transformation.streamline.absorb as absorb
@@ -96,6 +101,7 @@ def test_end2end_cnv_w2a2_import_and_tidy():
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_cnv_w2a2_tidy.onnx")
 
 
@@ -106,6 +112,7 @@ def test_end2end_cnv_w2a2_streamline():
     model = model.transform(MakeMaxPoolNHWC())
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_cnv_w2a2_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a1.py b/tests/end2end/test_end2end_tfc_w1a1.py
index ebfed5e571f1e7e2499c3501c6859239a329677a..31659df631e8ab489cb63dbef51200f313bca6b3 100644
--- a/tests/end2end/test_end2end_tfc_w1a1.py
+++ b/tests/end2end/test_end2end_tfc_w1a1.py
@@ -63,7 +63,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 )
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
@@ -98,12 +103,14 @@ def test_end2end_tfc_w1a1_import_and_tidy():
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
 
 
 def test_end2end_tfc_w1a1_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_tfc_w1a1_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a2.py b/tests/end2end/test_end2end_tfc_w1a2.py
index d4c005a86580fb36e735beb00717fcfdffff21e5..d5579f625a20ae26e18bcdcba0cfaa3042a71b9a 100644
--- a/tests/end2end/test_end2end_tfc_w1a2.py
+++ b/tests/end2end/test_end2end_tfc_w1a2.py
@@ -61,7 +61,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 )
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
@@ -93,12 +98,14 @@ def test_end2end_tfc_w1a2_import_and_tidy():
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
 
 
 def test_end2end_tfc_w1a2_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_tfc_w1a2_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w2a2.py b/tests/end2end/test_end2end_tfc_w2a2.py
index 19d3f86e046658c4080d71984df1cff74008adab..470119f3444987f0156caff81bf556bf4f2f2cbb 100644
--- a/tests/end2end/test_end2end_tfc_w2a2.py
+++ b/tests/end2end/test_end2end_tfc_w2a2.py
@@ -61,7 +61,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 )
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
@@ -93,12 +98,14 @@ def test_end2end_tfc_w2a2_import_and_tidy():
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
 
 
 def test_end2end_tfc_w2a2_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
     model.save(build_dir + "/end2end_tfc_w2a2_streamlined.onnx")
 
 
diff --git a/tests/fpgadataflow/test_fpgadataflow_labelselect.py b/tests/fpgadataflow/test_fpgadataflow_labelselect.py
index 2df841728395229dafe33d2804c44a3489ef3e45..9bc77cd47fd6115823f9a35d98e8874ee3f98b2d 100644
--- a/tests/fpgadataflow/test_fpgadataflow_labelselect.py
+++ b/tests/fpgadataflow/test_fpgadataflow_labelselect.py
@@ -27,6 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import pytest
+import numpy as np
 
 from onnx import TensorProto, helper
 
@@ -70,7 +71,8 @@ def make_labelselect_modelwrapper(labels, pe, k, idt):
     model = ModelWrapper(model)
 
     model.set_tensor_datatype("inp", idt)
-    model.set_tensor_datatype("outp", DataType.UINT32)
+    odt = DataType.get_smallest_possible(labels - 1)
+    model.set_tensor_datatype("outp", odt)
 
     return model
 
@@ -79,19 +81,18 @@ def prepare_inputs(input_tensor, idt):
     return {"inp": input_tensor}
 
 
-# TODO: folded inputs fail, likely problem in hlslib
-# input datatype -- checked by assertion in HLSCustomOp
-@pytest.mark.parametrize("idt", [DataType.UINT8, DataType.UINT16])
+@pytest.mark.parametrize("idt", [DataType.UINT8, DataType.UINT16, DataType.INT16])
 # labels
-@pytest.mark.parametrize("labels", [10, 1000])
+@pytest.mark.parametrize("labels", [10, 100])
 # folding
-@pytest.mark.parametrize("fold", [-1])
+@pytest.mark.parametrize("fold", [-1, 2, 10])
 # number of top labels to select
 @pytest.mark.parametrize("k", [1, 5])
 # execution mode
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
 @pytest.mark.vivado
 def test_fpgadataflow_labelselect(idt, labels, fold, k, exec_mode):
+    np.random.seed(0)
     if fold == -1:
         pe = 1
     else:
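
A sketch of the fold-to-PE mapping this test appears to use (an assumption
based on the parametrization above, not copied from the test body): fold == -1
selects no folding (PE = 1), otherwise PE = Labels / fold, and the op asserts
that PE divides Labels.

```python
def derive_pe(labels: int, fold: int) -> int:
    """Hypothetical helper mirroring the test's fold handling."""
    if fold == -1:
        return 1  # no folding: process one label per cycle
    assert labels % fold == 0, "fold must divide labels"
    return labels // fold

assert derive_pe(10, 2) == 5
assert derive_pe(100, 10) == 10
```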
diff --git a/tests/transformation/streamline/test_streamline_cnv.py b/tests/transformation/streamline/test_streamline_cnv.py
index 103967dfb6b86cc6e2ce2bc9ab78249d8945d47d..bcb66a2c22eb4d6a998580129881793bbc86b250 100644
--- a/tests/transformation/streamline/test_streamline_cnv.py
+++ b/tests/transformation/streamline/test_streamline_cnv.py
@@ -34,7 +34,12 @@ import pkg_resources as pk
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
 from finn.util.test import get_test_model_trained
@@ -62,6 +67,7 @@ def test_streamline_cnv(size, wbits, abits):
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     # load one of the test vectors
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
@@ -73,6 +79,9 @@ def test_streamline_cnv(size, wbits, abits):
     expected = expected_ctx[model.graph.output[0].name]
     # model.save("orig_cnv.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
+    assert len(model.graph.initializer) == 21
+    assert len(model.graph.value_info) == 43
     # model.save("streamlined_cnv.onnx")
     assert len(model.graph.node) == 23
     produced_ctx = oxe.execute_onnx(model, input_dict, True)
diff --git a/tests/transformation/streamline/test_streamline_fc.py b/tests/transformation/streamline/test_streamline_fc.py
index c68561239b7c30973856fa282d20cd2afaa168ae..dd7e756b4021af26c228804d4b509ecff032347e 100644
--- a/tests/transformation/streamline/test_streamline_fc.py
+++ b/tests/transformation/streamline/test_streamline_fc.py
@@ -37,7 +37,12 @@ import pytest
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedTensors,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
 from finn.util.test import get_test_model_trained
@@ -65,6 +70,7 @@ def test_streamline_fc(size, wbits, abits):
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -73,6 +79,10 @@ def test_streamline_fc(size, wbits, abits):
     expected_ctx = oxe.execute_onnx(model, input_dict, True)
     expected = expected_ctx[model.graph.output[0].name]
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedTensors())
+    assert len(model.graph.initializer) == 11
+    assert len(model.graph.value_info) == 21
+    assert len(model.graph.quantization_annotation) == 18
     produced_ctx = oxe.execute_onnx(model, input_dict, True)
     produced = produced_ctx[model.graph.output[0].name]
     assert np.isclose(expected, produced, atol=1e-3).all()
diff --git a/tests/transformation/test_absorb_opposite_transposes.py b/tests/transformation/test_absorb_opposite_transposes.py
new file mode 100644
index 0000000000000000000000000000000000000000..859e691277a261f01b559e2e166763e402c5d689
--- /dev/null
+++ b/tests/transformation/test_absorb_opposite_transposes.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import onnx.helper as oh
+from onnx import TensorProto
+
+import finn.core.onnx_exec as ox
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes
+
+
+def test_absorb_opposite_transposes():
+    np.random.seed(0)
+    input_shape = [1, 3, 4, 2]
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+    value_info = [oh.make_tensor_value_info("add_param_0", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("add_param_1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul_param_0", TensorProto.FLOAT, [1])]
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Add", ["top_in", "add_param_0"], ["t0"]),
+                oh.make_node("Transpose", ["t0"], ["t1"], perm=[0, 2, 3, 1]),
+                oh.make_node("Transpose", ["t1"], ["t2"], perm=[0, 3, 1, 2]),
+                oh.make_node("Add", ["t2", "add_param_1"], ["t3"]),
+                oh.make_node("Transpose", ["t3"], ["t4"], perm=[0, 2, 3, 1]),
+                oh.make_node("Transpose", ["t4"], ["t5"], perm=[0, 3, 1, 2]),
+                oh.make_node("Add", ["t5", "t2"], ["t6"]),
+                oh.make_node("Mul", ["t6", "mul_param_0"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+    model.set_initializer("add_param_0", np.asarray([1], dtype=np.float32))
+    model.set_initializer("add_param_1", np.asarray([3], dtype=np.float32))
+    model.set_initializer("mul_param_0", np.asarray([2], dtype=np.float32))
+    new_model = model.transform(AbsorbConsecutiveTransposes())
+    new_model = new_model.transform(InferShapes())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 4
+    for n in new_model.graph.node:
+        assert n.op_type != "Transpose"
diff --git a/tests/transformation/test_merge_onnx_models.py b/tests/transformation/test_merge_onnx_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7c990baddfb50a39603937a9c5b73f512a0e59
--- /dev/null
+++ b/tests/transformation/test_merge_onnx_models.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from pkgutil import get_data
+
+import numpy as np
+import onnx
+import onnx.numpy_helper as np_helper
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.merge_onnx_models import MergeONNXModels
+import finn.core.onnx_exec as oxe
+
+
+def test_merge_onnx_models():
+    # load pre model
+    raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
+    model1 = ModelWrapper(raw_m)
+    # the input for model1 comes from a uint8 vector, so we set the FINN
+    # datatype of the input tensor to DataType.UINT8 to verify that datatypes
+    # are correctly preserved in the transformed model
+    model1.set_tensor_datatype(model1.graph.input[0].name, DataType.UINT8)
+    model1 = model1.transform(InferShapes())
+    model1 = model1.transform(GiveUniqueNodeNames())
+    model1 = model1.transform(GiveReadableTensorNames())
+
+    # set up post model
+    shape = [1, 10]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, shape)
+    a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, [])
+    a1 = helper.make_tensor_value_info("a1", TensorProto.FLOAT, [])
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, shape)
+
+    mul_node = helper.make_node("Mul", ["inp", "a0"], ["mul_out"])
+    div_node = helper.make_node("Div", ["mul_out", "a1"], ["outp"])
+
+    graph = helper.make_graph(
+        nodes=[mul_node, div_node],
+        name="model2-graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[a0, a1],
+    )
+
+    model2 = helper.make_model(graph, producer_name="model2")
+    model2 = ModelWrapper(model2)
+    # initialize model2
+    a0_value = np.random.uniform(low=0, high=1, size=(1)).astype(np.float32)
+    model2.set_initializer("a0", a0_value)
+    a1_value = np.random.uniform(low=0.1, high=1, size=(1)).astype(np.float32)
+    model2.set_initializer("a1", a1_value)
+    # set a dummy sparsity annotation to check if it gets correctly transferred
+    # to the merged model
+    sparsity = {"dw": {"kernel_shape": 0}}
+    model2.set_tensor_sparsity("a1", sparsity)
+    model2 = model2.transform(InferShapes())
+    model2 = model2.transform(InferDataTypes())
+    model2 = model2.transform(InferDataLayouts())
+    model2 = model2.transform(GiveUniqueNodeNames())
+    model2 = model2.transform(GiveReadableTensorNames())
+
+    # simulate the models before the merging and pass the output of model1 to model2
+    # load one of the test vectors
+    raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
+    inp_values = onnx.load_tensor_from_string(raw_i)
+    inp_values = np_helper.to_array(inp_values)
+    idict = {model1.graph.input[0].name: inp_values}
+    odict = oxe.execute_onnx(model1, idict)
+    temp = odict[model1.graph.output[0].name]
+
+    idict = {model2.graph.input[0].name: temp}
+    odict = oxe.execute_onnx(model2, idict)
+    outp = odict[model2.graph.output[0].name]
+    # merge models
+    model_transformed = model2.transform(MergeONNXModels(model1))
+
+    idict = {model_transformed.graph.input[0].name: inp_values}
+    odict = oxe.execute_onnx(model_transformed, idict)
+    outp_transformed = odict[model_transformed.graph.output[0].name]
+
+    assert (outp == outp_transformed).all()
+    assert len(model_transformed.graph.node) == len(model1.graph.node) + len(
+        model2.graph.node
+    )
+    # the sparsity annotation of the Div node's input[1] was set to a dummy
+    # value above; find the Div node in the merged model and check that the
+    # annotation is preserved
+    for n in model_transformed.graph.node:
+        if n.op_type == "Div":
+            tensor_name = n.input[1]
+            set_sparsity = model_transformed.get_tensor_sparsity(tensor_name)
+            assert sparsity == set_sparsity
+
+    # check if finn datatype of graph.input[0] is still set to UINT8
+    assert model_transformed.get_tensor_datatype("global_in") == DataType.UINT8