diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 6bf468368648aa4ee0916c593f0a4d8ce75343f5..aab6839bac3dc15c96201ad0a311900d19487a3c 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -332,6 +332,22 @@ class ModelWrapper:
         else:
             return None
 
+    def is_fork_node(self, node):
+        """Checks if the given node is a fork, that is, the node has multiple
+        direct successors"""
+        direct_successors = self.find_direct_successors(node)
+        is_fork = False if direct_successors is None else (len(direct_successors) > 1)
+        return is_fork
+
+    def is_join_node(self, node):
+        """Checks if the given node is a join, that is, the node has multiple
+        direct predecessors"""
+        direct_predecessors = self.find_direct_predecessors(node)
+        is_join = (
+            False if direct_predecessors is None else (len(direct_predecessors) > 1)
+        )
+        return is_join
+
     def get_all_tensor_names(self):
         """Returns a list of all (input, output and value_info) tensor names
         in the graph."""
diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index c77fd81c0bfaa77b458368807410b8bfec17abb7..17a55e519ed0440f68e295aecaab179e6adf632f 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -40,6 +40,7 @@ from finn.util.basic import (
 from finn.util.fpgadataflow import (
     IPGenBuilder,
     pyverilate_get_liveness_threshold_cycles,
+    rtlsim_multi_io,
 )
 from . import templates
 
@@ -318,14 +319,24 @@ Found no codegen dir for this node, did you run the prepare_cppsim transformatio
             )
 
     def npy_to_dynamic_output(self, context):
-        """Reads the output from a .npy file and saves it at the right place in
-        the context dictionary."""
-        # TODO support multi-output nodes as needed
+        """Reads the output from an output.npy file generated from cppsim and
+        places its content into the context dictionary."""
         node = self.onnx_node
         code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
         output = np.load("{}/output.npy".format(code_gen_dir))
         context[node.output[0]] = output
 
+    def npy_to_dynamic_outputs(self, context, npy_list):
+        """Reads the output from .npy files generated from cppsim and places
+        their content into the context dictionary.
+        npy_list is a list specifying which files to read, and its order must
+        match the order of node outputs."""
+        node = self.onnx_node
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        for i in range(len(npy_list)):
+            output = np.load("{}/{}".format(code_gen_dir, npy_list[i]))
+            context[node.output[i]] = output
+
     def exec_precompiled_singlenode_model(self):
         """Executes precompiled executable."""
         executable_path = self.get_nodeattr("executable_path")
@@ -421,6 +432,16 @@ compilation transformations?
             sim.stop_vcd_trace()
         return outputs
 
+    def rtlsim_multi_io(self, sim, io_dict):
+        "Run rtlsim for this node, supports multiple i/o streams."
+
+        trace_file = self.get_nodeattr("rtlsim_trace")
+        if trace_file == "default":
+            trace_file = self.onnx_node.name + ".vcd"
+        num_out_values = self.get_number_output_values()
+        total_cycle_count = rtlsim_multi_io(sim, io_dict, num_out_values, trace_file)
+        self.set_nodeattr("sim_cycles", total_cycle_count)
+
     def execute_node(self, context, graph):
         """Executes single node using cppsim or rtlsim."""
         mode = self.get_nodeattr("exec_mode")
diff --git a/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..54051af5e0387081a23e1f8fa77ec9e363098830
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
@@ -0,0 +1,361 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+
+import numpy as np
+
+from finn.core.datatype import DataType
+from finn.custom_op.fpgadataflow import HLSCustomOp
+from onnx import TensorProto, helper
+from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
+
+
+class DuplicateStreams_Batch(HLSCustomOp):
+    """Class that corresponds to finn-hlslib function of the same name."""
+
+    def __init__(self, onnx_node):
+        super().__init__(onnx_node)
+
+    def get_nodeattr_types(self):
+        my_attrs = {
+            "NumChannels": ("i", True, 0),
+            "PE": ("i", True, 0),
+            # FINN DataTypes for input
+            "inputDataType": ("s", True, ""),
+            # number of input vectors, examples:
+            # [1] is a single vector (like a FC layer with batch=1)
+            # [4] is four vectors (like a FC layer with batch=4)
+            # [1, 4, 4] is four * four vectors (like a conv layer with batch=1)
+            "numInputVectors": ("ints", False, [1]),
+        }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
+
+    def get_normal_input_shape(self):
+        ch = self.get_nodeattr("NumChannels")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        ishape = tuple(vecs + [ch])
+        return ishape
+
+    def get_folded_input_shape(self):
+        ch = self.get_nodeattr("NumChannels")
+        pe = self.get_nodeattr("PE")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        assert ch % pe == 0, "PE must divide NumChannels"
+        folds = int(ch / pe)
+        folded_ishape = tuple(vecs + [folds, pe])
+        return folded_ishape
+
+    def get_normal_output_shape(self):
+        return self.get_normal_input_shape()
+
+    def get_folded_output_shape(self):
+        return self.get_folded_input_shape()
+
+    def make_shape_compatible_op(self, model):
+        exp_ishape = self.get_normal_input_shape()
+        oshape = self.get_normal_output_shape()
+        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
+        assert ishape == exp_ishape, "Unexpected input shape."
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        split_input = np.concatenate((values, values), axis=0)
+        return helper.make_node(
+            "Split",
+            inputs=[split_input],
+            outputs=[self.onnx_node.output[0], self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor", data_type=TensorProto.FLOAT, axis=0
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        odt = self.get_output_datatype()
+        model.set_tensor_datatype(self.onnx_node.output[0], odt)
+
+    def verify_node(self):
+        info_messages = []
+        # verify that "domain" is set to "finn"
+        domain_value = self.onnx_node.domain
+        if domain_value == "finn":
+            info_messages.append("Attribute domain is set correctly")
+        else:
+            info_messages.append('Attribute domain should be set to "finn"')
+
+        # verify that "backend" is set to "fpgadataflow"
+        backend_value = self.get_nodeattr("backend")
+        if backend_value == "fpgadataflow":
+            info_messages.append("Attribute backend is set correctly")
+        else:
+            info_messages.append('Attribute backend should be set to "fpgadataflow"')
+
+        # verify that all necessary attributes exist
+        try:
+            self.get_nodeattr("code_gen_dir_cppsim")
+            self.get_nodeattr("executable_path")
+            self.get_nodeattr("NumChannels")
+            self.get_nodeattr("PE")
+            self.get_nodeattr("inputDataType")
+            info_messages.append("All necessary attributes exist")
+        except Exception:
+            info_messages.append(
+                """The required GlobalAccPool_Batch attributes do not exist."""
+            )
+
+        return info_messages
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_instream_width(self):
+        """Returns input stream width."""
+        ibits = self.get_input_datatype().bitwidth()
+        pe = self.get_nodeattr("PE")
+        in_width = pe * ibits
+        return in_width
+
+    def get_outstream_width(self):
+        """Returns output stream width."""
+        obits = self.get_output_datatype().bitwidth()
+        pe = self.get_nodeattr("PE")
+        out_width = pe * obits
+        return out_width
+
+    def get_number_output_values(self):
+        return 2 * np.prod(self.get_folded_output_shape()[1:-1])
+
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+        exp_ishape = self.get_normal_input_shape()
+        exp_oshape = self.get_normal_output_shape()
+        folded_ishape = self.get_folded_input_shape()
+        folded_oshape = self.get_folded_output_shape()
+
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        inp = context[node.input[0]]
+        assert str(inp.dtype) == "float32", "Input datatype is not float32"
+        assert inp.shape == exp_ishape, """Input shape doesn't match expected shape ."""
+        export_idt = self.get_input_datatype()
+        # reshape input into folded form
+        inp = inp.reshape(folded_ishape)
+        # make copy before saving array
+        reshaped_input = inp.copy()
+        np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_outputs(context, ["output0.npy", "output1.npy"])
+            assert (
+                context[node.output[0]].shape == folded_oshape
+            ), "cppsim \
+            did not produce expected ofolded utput shape"
+            assert (
+                context[node.output[1]].shape == folded_oshape
+            ), "cppsim \
+            did not produce expected ofolded utput shape"
+            context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
+            context[node.output[1]] = context[node.output[1]].reshape(*exp_oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            rtlsim_inp = npy_to_rtlsim_input(
+                "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+            )
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
+            rtlsim_dict = {
+                "inputs": {"in0": rtlsim_inp},
+                "outputs": {"out0": [], "out1": []},
+            }
+            self.rtlsim_multi_io(sim, rtlsim_dict)
+            odt = self.get_output_datatype()
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_shape = self.get_folded_output_shape()
+
+            out_npy_path = "{}/output0.npy".format(code_gen_dir)
+            rtlsim_output_to_npy(
+                rtlsim_dict["outputs"]["out0"],
+                out_npy_path,
+                odt,
+                out_shape,
+                packed_bits,
+                target_bits,
+            )
+            # load and reshape output 0
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[0]] = output
+
+            out_npy_path = "{}/output1.npy".format(code_gen_dir)
+            rtlsim_output_to_npy(
+                rtlsim_dict["outputs"]["out1"],
+                out_npy_path,
+                odt,
+                out_shape,
+                packed_bits,
+                target_bits,
+            )
+            # load and reshape output 1
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[1]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        assert (
+            context[node.output[0]].shape == exp_oshape
+        ), """Output0 shape doesn't match expected shape."""
+        assert (
+            context[node.output[1]].shape == exp_oshape
+        ), """Output1 shape doesn't match expected shape."""
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "streamtools.h"']
+
+    def defines(self, var):
+        self.code_gen_dict["$DEFINES$"] = []
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out0 ("out0");'.format(self.get_outstream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out1 ("out1");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            """DuplicateStreams_Batch<{}, {}> (in0, out0, out1, 1);""".format(
+                self.get_outstream_width(), self.get_number_output_values() // 2,
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_out = "%s/output0.npy" % code_gen_dir
+        npy_out1 = "%s/output1.npy" % code_gen_dir
+        oshape = self.get_folded_output_shape()
+        oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
+
+        self.code_gen_dict["$DATAOUTSTREAM$"] = [
+            'apintstream2npy<%s, %s, %d, %s>(out0, %s, "%s");'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                oshape_cpp_str,
+                npy_out,
+            )
+        ]
+
+        self.code_gen_dict["$DATAOUTSTREAM$"] += [
+            'apintstream2npy<%s, %s, %d, %s>(out1, %s, "%s");'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                oshape_cpp_str,
+                npy_out1,
+            )
+        ]
+
+    def save_as_npy(self):
+        self.code_gen_dict["$SAVEASCNPY$"] = []
+
+    def blackboxfunction(self):
+        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+            """void {}(hls::stream<ap_uint<{}>> &in0,
+                hls::stream<ap_uint<{}>> &out0,
+                hls::stream<ap_uint<{}>> &out1)""".format(
+                self.onnx_node.name,
+                self.get_instream_width(),
+                self.get_outstream_width(),
+                self.get_outstream_width(),
+            )
+        ]
+
+    def pragmas(self):
+        self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out0")
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out1")
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS INTERFACE ap_ctrl_none port=return"
+        )
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index 67f59b49a4617355897c7adfb314f2710af04b71..0d62862c222b44d2e507a90a80bfcd4fa405d3fe 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -47,6 +47,7 @@ from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
 from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
+from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
 
 # create a mapping of all known CustomOp names and classes
 custom_op = {}
@@ -66,6 +67,7 @@ custom_op["GlobalAccPool_Batch"] = GlobalAccPool_Batch
 custom_op["Thresholding_Batch"] = Thresholding_Batch
 custom_op["AddStreams_Batch"] = AddStreams_Batch
 custom_op["LabelSelect_Batch"] = LabelSelect_Batch
+custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
 
 
 def getCustomOp(node):
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 1886c785705161c3a13493de44dc3f3f86463f4f..b91ffdb3f731d27d9a6ba68b090f3881e6d7293a 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -36,8 +36,9 @@ from finn.util.basic import get_by_name
 
 
 class MoveAddPastMul(Transformation):
-    """Move add operations past multiply operations. The aim is to have them
-    next to each other such that they can be collapsed into a single add."""
+    """Move add operations past multiply operations on linear segments of the graph.
+    The aim is to have them next to each other such that they can be collapsed into
+    a single add."""
 
     def apply(self, model):
         graph = model.graph
@@ -45,9 +46,17 @@ class MoveAddPastMul(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Mul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Mul"
+                    and not model.is_join_node(consumer)
+                ):
                     # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
                     # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
                     # assume input 0 is from the previous layer, input 1 is the
@@ -63,12 +72,16 @@ class MoveAddPastMul(Transformation):
                     end_name = consumer.output[0]
                     # compute new param value for add
                     BA = B * A
+
                     # make and insert new nodes
                     new_mul = oh.make_node(
-                        "Mul", [start_name, mul_weight_name], [middle_name]
+                        "Mul",
+                        [start_name, mul_weight_name],
+                        [middle_name],
+                        name=consumer.name,
                     )
                     new_add = oh.make_node(
-                        "Add", [middle_name, add_weight_name], [end_name]
+                        "Add", [middle_name, add_weight_name], [end_name], name=n.name
                     )
                     graph.node.insert(node_ind, new_mul)
                     graph.node.insert(node_ind + 1, new_add)
@@ -78,6 +91,7 @@ class MoveAddPastMul(Transformation):
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
+
         model = model.transform(InferShapes())
         return (model, graph_modified)
 
@@ -92,9 +106,17 @@ class MoveScalarMulPastMatMul(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "MatMul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MatMul"
+                    and not model.is_join_node(consumer)
+                ):
                     mul_weight_name = n.input[1]
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
@@ -109,10 +131,16 @@ class MoveScalarMulPastMatMul(Transformation):
                         # if the mul is scalar, we can simply swap the order of ops
                         # make and insert new nodes
                         new_matmul = oh.make_node(
-                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                            "MatMul",
+                            [start_name, matmul_weight_name],
+                            [middle_name],
+                            name=consumer.name,
                         )
                         new_mul = oh.make_node(
-                            "Mul", [middle_name, mul_weight_name], [end_name]
+                            "Mul",
+                            [middle_name, mul_weight_name],
+                            [end_name],
+                            name=n.name,
                         )
                         graph.node.insert(node_ind, new_matmul)
                         graph.node.insert(node_ind + 1, new_mul)
@@ -135,9 +163,17 @@ class MoveScalarAddPastMatMul(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "MatMul":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "MatMul"
+                    and not model.is_join_node(consumer)
+                ):
                     add_weight_name = n.input[1]
                     matmul_weight_name = consumer.input[1]
                     A = model.get_initializer(add_weight_name)
@@ -155,10 +191,16 @@ class MoveScalarAddPastMatMul(Transformation):
                         # update the add weight
                         model.set_initializer(add_weight_name, Anew)
                         new_matmul = oh.make_node(
-                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                            "MatMul",
+                            [start_name, matmul_weight_name],
+                            [middle_name],
+                            name=consumer.name,
                         )
                         new_add = oh.make_node(
-                            "Add", [middle_name, add_weight_name], [end_name]
+                            "Add",
+                            [middle_name, add_weight_name],
+                            [end_name],
+                            name=n.name,
                         )
                         graph.node.insert(node_ind, new_matmul)
                         graph.node.insert(node_ind + 1, new_add)
@@ -181,9 +223,17 @@ class MoveScalarAddPastConv(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Conv":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
                     conv_node = consumer
                     add_node = n
                     add_weight_name = n.input[1]
@@ -238,9 +288,17 @@ class MoveScalarMulPastConv(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == "Conv":
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
                     mul_weight_name = n.input[1]
                     A = model.get_initializer(mul_weight_name)
                     assert A is not None, "Initializer for mul weights is not set."
diff --git a/src/finn/util/fpgadataflow.py b/src/finn/util/fpgadataflow.py
index 7b66d092107c27decca68926a0667333bebedbe0..d1669444e55cb0fddb2690e51849c4603d47d32c 100644
--- a/src/finn/util/fpgadataflow.py
+++ b/src/finn/util/fpgadataflow.py
@@ -127,3 +127,91 @@ def is_fpgadataflow_node(node):
                     is_node = True
 
     return is_node
+
+
+def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file=""):
+    """Runs the pyverilator simulation by passing the input values to the simulation,
+    toggle the clock and observing the execution time. Function contains also an
+    observation loop that can abort the simulation if no output value is produced
+    after a set number of cycles. Can handle multiple i/o streams. See function
+    implementation for details on how the top-level signals should be named.
+
+    sim: the PyVerilator object for simulation
+    io_dict: a dict of dicts in the following format:
+            {"inputs" : {"in0" : <input_data>, "in1" : <input_data>},
+             "outputs" : {"out0" : [], "out1" : []} }
+            <input_data> is a list of Python arbitrary-precision ints indicating
+            what data to push into the simulation, and the output lists are
+            similarly filled when the simulation is complete
+    num_out_values: number of total values to be read from the simulation to
+                    finish the simulation and return.
+
+    returns: number of clock cycles elapsed for completion
+
+    """
+
+    if trace_file != "":
+        sim.start_vcd_trace(trace_file)
+
+    for outp in io_dict["outputs"]:
+        sim.io[outp + "_V_V_TREADY"] = 1
+
+    # observe if output is completely calculated
+    # total_cycle_count will contain the number of cycles the calculation ran
+    output_done = False
+    total_cycle_count = 0
+    output_count = 0
+    old_output_count = 0
+
+    # avoid infinite looping of simulation by aborting when there is no change in
+    # output values after 100 cycles
+    no_change_count = 0
+    liveness_threshold = pyverilate_get_liveness_threshold_cycles()
+
+    while not (output_done):
+        for inp in io_dict["inputs"]:
+            inputs = io_dict["inputs"][inp]
+            sim.io[inp + "_V_V_TVALID"] = 1 if len(inputs) > 0 else 0
+            sim.io[inp + "_V_V_TDATA"] = inputs[0] if len(inputs) > 0 else 0
+            if sim.io[inp + "_V_V_TREADY"] == 1 and sim.io[inp + "_V_V_TVALID"] == 1:
+                inputs = inputs[1:]
+            io_dict["inputs"][inp] = inputs
+
+        for outp in io_dict["outputs"]:
+            outputs = io_dict["outputs"][outp]
+            if sim.io[outp + "_V_V_TVALID"] == 1 and sim.io[outp + "_V_V_TREADY"] == 1:
+                outputs = outputs + [sim.io[outp + "_V_V_TDATA"]]
+                output_count += 1
+            io_dict["outputs"][outp] = outputs
+
+        sim.io.ap_clk = 1
+        sim.io.ap_clk = 0
+
+        total_cycle_count = total_cycle_count + 1
+
+        if output_count == old_output_count:
+            no_change_count = no_change_count + 1
+        else:
+            no_change_count = 0
+            old_output_count = output_count
+
+        # check if all expected output words received
+        if output_count == num_out_values:
+            output_done = True
+
+        # end sim on timeout
+        if no_change_count == liveness_threshold:
+            if trace_file != "":
+                sim.flush_vcd_trace()
+                sim.stop_vcd_trace()
+            raise Exception(
+                "Error in simulation! Takes too long to produce output. "
+                "Consider setting the LIVENESS_THRESHOLD env.var. to a "
+                "larger value."
+            )
+
+    if trace_file != "":
+        sim.flush_vcd_trace()
+        sim.stop_vcd_trace()
+
+    return total_cycle_count
diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py
index d1da6934a5db07aabe41a9ca40b5de497b6460a1..4bd9385536bc6721c66726169dfa4c69e5f06772 100644
--- a/tests/core/test_modelwrapper.py
+++ b/tests/core/test_modelwrapper.py
@@ -127,3 +127,45 @@ def test_modelwrapper_graph_order():
     assert model.get_node_index(Round_node) == 1
     assert model.get_node_index(Ceil_node) == 2
     assert model.get_node_index(Add_node) == 3
+
+
+def test_modelwrapper_detect_forks_n_joins():
+    # create small network with properties to be tested
+    Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])
+    Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"])
+
+    Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
+    Add_node = onnx.helper.make_node(
+        "Add", inputs=["round1", "ceil1"], outputs=["out1"]
+    )
+
+    in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
+    out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
+
+    graph = onnx.helper.make_graph(
+        nodes=[Neg_node, Round_node, Ceil_node, Add_node],
+        name="simple_graph",
+        inputs=[in1],
+        outputs=[out1],
+        value_info=[
+            onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
+            onnx.helper.make_tensor_value_info(
+                "round1", onnx.TensorProto.FLOAT, [4, 4]
+            ),
+            onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
+        ],
+    )
+
+    onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    # test
+    assert model.is_fork_node(Neg_node)
+    assert not model.is_fork_node(Round_node)
+    assert not model.is_fork_node(Ceil_node)
+    assert not model.is_fork_node(Add_node)
+
+    assert not model.is_join_node(Neg_node)
+    assert not model.is_join_node(Round_node)
+    assert not model.is_join_node(Ceil_node)
+    assert model.is_join_node(Add_node)
diff --git a/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb84be59333ef0e696204c9064fcf77e35b5d9b
--- /dev/null
+++ b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
+    ReplaceVerilogRelPaths,
+)
+
+
+def make_dupstreams_modelwrapper(ch, pe, idim, idt):
+    shape = [1, idim, idim, ch]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, shape)
+    outp0 = helper.make_tensor_value_info("outp0", TensorProto.FLOAT, shape)
+    outp1 = helper.make_tensor_value_info("outp1", TensorProto.FLOAT, shape)
+
+    dupstrm_node = helper.make_node(
+        "DuplicateStreams_Batch",
+        ["inp"],
+        ["outp0", "outp1"],
+        domain="finn",
+        backend="fpgadataflow",
+        NumChannels=ch,
+        PE=pe,
+        inputDataType=idt.name,
+        numInputVectors=[1, idim, idim],
+    )
+    graph = helper.make_graph(
+        nodes=[dupstrm_node], name="graph", inputs=[inp], outputs=[outp0, outp1]
+    )
+
+    model = helper.make_model(graph, producer_name="addstreams-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+
+    return model
+
+
+def prepare_inputs(input_tensor, idt):
+    return {"inp": input_tensor}
+
+
+# data type
+@pytest.mark.parametrize("idt", [DataType.INT4, DataType.UINT16])
+# channels
+@pytest.mark.parametrize("ch", [64])
+# folding
+@pytest.mark.parametrize("fold", [-1, 2, 1])
+# image dimension
+@pytest.mark.parametrize("imdim", [7])
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+@pytest.mark.vivado
+def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, exec_mode):
+    if fold == -1:
+        pe = 1
+    else:
+        pe = ch // fold
+    assert ch % pe == 0
+
+    # generate input data
+    x = gen_finn_dt_tensor(idt, (1, imdim, imdim, ch))
+
+    model = make_dupstreams_modelwrapper(ch, pe, imdim, idt)
+
+    if exec_mode == "cppsim":
+        model = model.transform(PrepareCppSim())
+        model = model.transform(CompileCppSim())
+        model = model.transform(SetExecMode("cppsim"))
+    elif exec_mode == "rtlsim":
+        model = model.transform(SetExecMode("rtlsim"))
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(PrepareIP("xc7z020clg400-1", 5))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(ReplaceVerilogRelPaths())
+        model = model.transform(PrepareRTLSim())
+    else:
+        raise Exception("Unknown exec_mode")
+
+    # prepare input data and execute
+    input_dict = prepare_inputs(x, idt)
+    output_dict = oxe.execute_onnx(model, input_dict)
+    y0 = output_dict["outp0"]
+    y1 = output_dict["outp1"]
+    expected_y = x
+
+    assert (y0 == expected_y).all(), exec_mode + " failed"
+    assert (y1 == expected_y).all(), exec_mode + " failed"
diff --git a/tests/transformation/test_move_add_past_mul.py b/tests/transformation/test_move_add_past_mul.py
index a0516d6fb2ff985fc112185ce99ad8facd841caf..163b9d310a5f12bd0b854f9aa46f53a549bf109e 100644
--- a/tests/transformation/test_move_add_past_mul.py
+++ b/tests/transformation/test_move_add_past_mul.py
@@ -60,6 +60,9 @@ def test_move_add_past_mul_single():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
 
 
 def test_move_add_past_mul_multi():
@@ -92,3 +95,50 @@ def test_move_add_past_mul_multi():
     new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Add"
+    assert new_model.graph.node[3].op_type == "Add"
+    for i in range(len(new_model.graph.node) - 1):
+        assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0]
+
+
+def test_move_add_past_mul_only_if_linear():
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2])
+
+    value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])]
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Add", ["top_in", "add1_param"], ["t1"]),
+                oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]),
+                oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]),
+                oh.make_node("Add", ["t3", "fork"], ["t4"]),
+                oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("add1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32))
+    model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32))
+    new_model = model.transform(MoveAddPastMul())
+    inp_dict = {"top_in": np.random.rand(2).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Mul"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Add"
+    assert new_model.graph.node[4].op_type == "Mul"
diff --git a/tests/transformation/test_move_scalar_past_conv.py b/tests/transformation/test_move_scalar_past_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9992d17b96ab5f419f3ac495f126ddfa736349a2
--- /dev/null
+++ b/tests/transformation/test_move_scalar_past_conv.py
@@ -0,0 +1,87 @@
+import numpy as np
+import onnx.helper as oh
+import pytest
+from onnx import TensorProto
+
+import finn.core.onnx_exec as ox
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import (
+    MoveScalarAddPastConv,
+    MoveScalarMulPastConv,
+)
+
+
+@pytest.mark.parametrize(
+    "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
+)
+def test_move_scalar_past_conv_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    in_feature_dim = 7
+    in_chn = 1
+    padding = False
+    stages = 3
+    kernel_size = 3
+
+    out_feature_dim = (
+        in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
+    )
+
+    input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
+    output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
+
+    conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
+
+    conv_config = {}
+    conv_config["dilations"] = [1, 1]
+    conv_config["group"] = 1
+    conv_config["kernel_shape"] = [kernel_size, kernel_size]
+    conv_config["pads"] = [0, 0, 0, 0]
+    conv_config["strides"] = [1, 1]
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, conv_param_shape)]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node("Conv", ["top_in", "p2"], ["t1"], **conv_config),
+                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
+                oh.make_node("Conv", ["t2", "p3"], ["t3"], **conv_config),
+                oh.make_node("Conv", ["t2", "p4"], ["t4"], **conv_config),
+                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
+                oh.make_node("Conv", ["t5", "p5"], ["top_out"], **conv_config),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p4", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p5", np.random.rand(*conv_param_shape).astype(np.float32))
+    new_model = model.transform(transf_fxn)
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Conv"
+    assert new_model.graph.node[1].op_type == scalar_op
+    assert new_model.graph.node[2].op_type == "Conv"
+    assert new_model.graph.node[3].op_type == "Conv"
+    assert new_model.graph.node[4].op_type == scalar_op
+    assert new_model.graph.node[5].op_type == "Conv"
diff --git a/tests/transformation/test_move_scalar_past_matmul.py b/tests/transformation/test_move_scalar_past_matmul.py
index 896527e82d8cfa869cb979d1102904c70703a14c..e432dbf4ec1a38551609e5914e2d44968a020908 100644
--- a/tests/transformation/test_move_scalar_past_matmul.py
+++ b/tests/transformation/test_move_scalar_past_matmul.py
@@ -27,6 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
+import pytest
 import onnx.helper as oh
 from onnx import TensorProto
 
@@ -99,3 +100,56 @@ def test_move_scalar_add_past_matmul():
     assert new_model.graph.node[0].op_type == "MatMul"
     assert new_model.graph.node[1].op_type == "Add"
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
+
+
+@pytest.mark.parametrize(
+    "test_args",
+    [("Add", MoveScalarAddPastMatMul()), ("Mul", MoveScalarMulPastMatMul())],
+)
+def test_move_scalar_past_matmul_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+    input_shape = [1, 2]
+    matmul_shape = [2, 2]
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    p1 = oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1, 1])
+    p2 = oh.make_tensor_value_info("p2", TensorProto.FLOAT, matmul_shape)
+    p3 = oh.make_tensor_value_info("p3", TensorProto.FLOAT, matmul_shape)
+    p4 = oh.make_tensor_value_info("p4", TensorProto.FLOAT, matmul_shape)
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=[p1, p2, p3, p4],
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
+                oh.make_node("MatMul", ["t1", "p2"], ["fork"]),
+                oh.make_node("MatMul", ["fork", "p3"], ["t3"]),
+                oh.make_node(scalar_op, ["t3", "fork"], ["t4"]),
+                oh.make_node("MatMul", ["t4", "p4"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", np.random.rand(1, 1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*matmul_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*matmul_shape).astype(np.float32))
+    model.set_initializer("p4", np.random.rand(*matmul_shape).astype(np.float32))
+
+    # Transform
+    new_model = model.transform(transf_fxn)
+
+    # Test
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MatMul"
+    assert new_model.graph.node[1].op_type == scalar_op
+    assert new_model.graph.node[2].op_type == "MatMul"
+    assert new_model.graph.node[3].op_type == scalar_op
+    assert new_model.graph.node[4].op_type == "MatMul"