diff --git a/src/finn/core/remote_exec.py b/src/finn/core/remote_exec.py
index 335dfec04e4abee41f914c5d912ce291a0d31a91..a533e4d36629f57f7c4a576570d75a1e051de5be 100644
--- a/src/finn/core/remote_exec.py
+++ b/src/finn/core/remote_exec.py
@@ -79,6 +79,12 @@ def remote_exec(model, execution_context):
     bash_command = ["/bin/bash", "-c", cmd]
     process_compile = subprocess.Popen(bash_command, stdout=subprocess.PIPE)
     process_compile.communicate()
+    # remove stale output file from local dir, if any
+    try:
+        os.remove("{}/output.npy".format(deployment_dir))
+    except FileNotFoundError:
+        pass
+    # copy generated output to local
     cmd = "sshpass -p {} scp -P{} {}@{}:{}/{}/output.npy {}".format(
         pynq_password,
         pynq_port,
diff --git a/src/finn/core/throughput_test.py b/src/finn/core/throughput_test.py
index c82d540e29fc59b92a22bf011e823a9f8c076843..8d3dabcf8af51327d5d951464c6d9b36e2f67497 100644
--- a/src/finn/core/throughput_test.py
+++ b/src/finn/core/throughput_test.py
@@ -30,10 +30,11 @@ import os
 import subprocess
 
 
-def throughput_test(model):
+def throughput_test(model, batchsize=1000):
     """Runs the throughput test for the given model remotely on the pynq board.
     The metadata properties related to the pynq board have to be set.
-    Returns a dictionary with results of the throughput test"""
+    Returns a dictionary with results of the throughput test. Returns None
+    if the test fails."""
 
     pynq_ip = model.get_metadata_prop("pynq_ip")
     pynq_port = int(model.get_metadata_prop("pynq_port"))
@@ -47,7 +48,8 @@ def throughput_test(model):
     cmd = (
         "sshpass -p {} ssh {}@{} -p {} "
         '"cd {}/{}; echo "{}" | '
-        'sudo -S python3.6 driver.py --exec_mode="throughput_test" --batchsize=1000"'
+        'sudo -S python3.6 driver.py --exec_mode="throughput_test" --batchsize=%d"'
+        % batchsize
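+        # note: adjacent string literals concatenate before the % applies, so
+        # %d is filled in here; the {} fields are filled by .format below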
     ).format(
         pynq_password,
         pynq_username,
@@ -61,6 +63,12 @@ def throughput_test(model):
     process_compile = subprocess.Popen(bash_command, stdout=subprocess.PIPE)
     process_compile.communicate()
 
+    # remove any pre-existing metrics file
+    try:
+        os.remove("{}/nw_metrics.txt".format(deployment_dir))
+    except FileNotFoundError:
+        pass
+
     cmd = "sshpass -p {} scp -P{} {}@{}:{}/{}/nw_metrics.txt {}".format(
         pynq_password,
         pynq_port,
@@ -74,7 +82,9 @@ def throughput_test(model):
     process_compile = subprocess.Popen(bash_command, stdout=subprocess.PIPE)
     process_compile.communicate()
 
-    with open("{}/nw_metrics.txt".format(deployment_dir), "r") as file:
-        res = eval(file.read())
-
-    return res
+    try:
+        with open("{}/nw_metrics.txt".format(deployment_dir), "r") as file:
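+            # the driver writes the metrics dict via str(), so eval() turns
+            # the file contents back into a Python dictionary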
+            res = eval(file.read())
+        return res
+    except FileNotFoundError:
+        return None
diff --git a/src/finn/custom_op/fpgadataflow/sameresize_batch.py b/src/finn/custom_op/fpgadataflow/sameresize_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c459cac1e9c17336200a1fc85aad2af5e14e2c61
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/sameresize_batch.py
@@ -0,0 +1,298 @@
+import os
+import numpy as np
+from onnx import TensorProto, helper
+from finn.core.datatype import DataType
+from finn.custom_op.fpgadataflow import HLSCustomOp
+from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
+
+
+class SameResize_Batch(HLSCustomOp):
+    """Class that corresponds to finn-hlslib SameResize function.
+    Implements 'same' padding on a given input image."""
+
+    def __init__(self, onnx_node):
+        super().__init__(onnx_node)
+
+    def get_nodeattr_types(self):
+        my_attrs = {
+            "ImgDim": ("i", True, 0),
+            "KernelDim": ("i", True, 0),
+            "Stride": ("i", True, 0),
+            "NumChannels": ("i", True, 0),
+            # FINN input datatype
+            "inputDataType": ("s", True, ""),
+            # distribution of added values to achieve "same" padding
+            "PaddingStyle": ("i", True, 2),
+        }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
+
+    def get_normal_input_shape(self):
+        idim = self.get_nodeattr("ImgDim")
+        num_ch = self.get_nodeattr("NumChannels")
+
+        ishape = (1, idim, idim, num_ch)
+        return ishape
+
+    def get_normal_output_shape(self):
+        idim = self.get_nodeattr("ImgDim")
+        num_ch = self.get_nodeattr("NumChannels")
+        kdim = self.get_nodeattr("KernelDim")
+        stride = self.get_nodeattr("Stride")
+        assert idim % stride == 0, "Stride must divide input dimension."
+        # number of "same" windows over the input data
+        same_windows = idim // stride
+        odim = kdim + stride * (same_windows - 1)
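+        # e.g. idim=8, kdim=3, stride=1: 8 windows, odim = 3 + 1*7 = 10,
+        # i.e. a total "same" padding of 2 around the 8x8 input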
+
+        oshape = (1, odim, odim, num_ch)
+        return oshape
+
+    def get_folded_input_shape(self):
+        # even though there is no folding in the current hlslib op,
+        # insert a time multiplexing axis to remain compatible with the
+        # shapes produced by the rest of the dataflow pipeline
+        ret = list(self.get_normal_input_shape())
+        ret.insert(-1, 1)
+        return tuple(ret)
+
+    def get_folded_output_shape(self):
+        # even though there is no folding in the current hlslib op,
+        # insert a time multiplexing axis to remain compatible with the
+        # shapes produced by the rest of the dataflow pipeline
+        ret = list(self.get_normal_output_shape())
+        ret.insert(-1, 1)
+        return tuple(ret)
+
+    def make_shape_compatible_op(self, model):
+        exp_ishape = self.get_normal_input_shape()
+        oshape = self.get_normal_output_shape()
+        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
+        assert ishape == exp_ishape, "Unexpected input shape for SameResize_Batch."
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # data type stays the same
+        dtype = model.get_tensor_datatype(node.input[0])
+        exp_idtype = self.get_input_datatype()
+        assert dtype == exp_idtype, "Unexpected datatype for SameResize_Batch"
+        model.set_tensor_datatype(node.output[0], dtype)
+
+    def verify_node(self):
+        pass
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        ret = DataType[self.get_nodeattr("inputDataType")]
+        # the hlslib op always pads with zeroes, so ensure that the DataType
+        # is able to represent zeroes
+        assert ret.allowed(0), "SameResize_Batch DataType must support zero"
+        return ret
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output. (Same as input datatype)"""
+        return self.get_input_datatype()
+
+    def get_instream_width(self):
+        ibits = self.get_input_datatype().bitwidth()
+        num_ch = self.get_nodeattr("NumChannels")
+
+        return ibits * num_ch
+
+    def get_outstream_width(self):
+        obits = self.get_output_datatype().bitwidth()
+        num_ch = self.get_nodeattr("NumChannels")
+
+        return obits * num_ch
+
+    def get_number_output_values(self):
+        folded_oshape = self.get_folded_output_shape()
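+        # the last (folded) axis holds the channels packed into one stream
+        # word, so it does not count towards the number of output values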
+        return np.prod(folded_oshape[:-1])
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "streamtools.h"']
+
+    def defines(self, var):
+        numReps = 1
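+        # only PaddingStyle=2 is supported: padding is split evenly, with
+        # the extra row/column going to the top/left when the total is odd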
+        assert self.get_nodeattr("PaddingStyle") == 2, "Only PaddingStyle=2 supported"
+        self.code_gen_dict["$DEFINES$"] = [
+            """#define ImgDim1 {}\n #define KernelDim1 {}\n
+            #define Stride1 {}\n #define NumChannels1 {}\n
+            #define PaddingStyle1 {}\n #define numReps {}""".format(
+                self.get_nodeattr("ImgDim"),
+                self.get_nodeattr("KernelDim"),
+                self.get_nodeattr("Stride"),
+                self.get_nodeattr("NumChannels"),
+                self.get_nodeattr("PaddingStyle"),
+                numReps,
+            )
+        ]
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        if dtype == DataType.BIPOLAR:
+            # use binary for bipolar storage
+            dtype = DataType.BINARY
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        in_t = self.get_input_datatype().get_hls_datatype_str()
+        node = self.onnx_node
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            """{}<ImgDim1, KernelDim1, Stride1, NumChannels1,
+                {}, PaddingStyle1> (in0, out, numReps);""".format(
+                node.op_type, in_t
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        if dtype == DataType.BIPOLAR:
+            # use binary for bipolar storage
+            dtype = DataType.BINARY
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_out = "%s/output.npy" % code_gen_dir
+        oshape = self.get_folded_output_shape()
+        oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
+
+        self.code_gen_dict["$DATAOUTSTREAM$"] = [
+            'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s");'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                oshape_cpp_str,
+                npy_out,
+            )
+        ]
+
+    def save_as_npy(self):
+        self.code_gen_dict["$SAVEASCNPY$"] = []
+
+    def blackboxfunction(self):
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+            "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)"
+            % (self.onnx_node.name, packed_hls_type, packed_hls_type)
+        ]
+
+    def pragmas(self):
+        self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out")
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS INTERFACE ap_ctrl_none port=return"
+        )
+
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+        exp_ishape = self.get_normal_input_shape()
+        exp_oshape = self.get_normal_output_shape()
+        folded_oshape = self.get_folded_output_shape()
+
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        inp = context[node.input[0]]
+        assert str(inp.dtype) == "float32", "Input datatype is not float32"
+        assert (
+            inp.shape == exp_ishape
+        ), """Input shape doesn't
+        match expected shape (1, ImgDim, ImgDim, NumChannels)."""
+        export_idt = self.get_input_datatype()
+
+        # no reshaping for input since assuming no folding on input
+        # make copy before saving array
+        inp = inp.copy()
+        np.save(os.path.join(code_gen_dir, "input_0.npy"), inp)
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_output(context)
+            assert (
+                context[node.output[0]].shape == folded_oshape
+            ), "cppsim \
+            did not produce expected ofolded utput shape"
+            context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            rtlsim_inp = npy_to_rtlsim_input(
+                "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+            )
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
+            rtlsim_output = self.rtlsim(sim, rtlsim_inp)
+            odt = export_idt
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_npy_path = "{}/output.npy".format(code_gen_dir)
+            out_shape = self.get_folded_output_shape()
+            rtlsim_output_to_npy(
+                rtlsim_output, out_npy_path, odt, out_shape, packed_bits, target_bits
+            )
+            # load and reshape output
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[0]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+        assert (
+            context[node.output[0]].shape == exp_oshape
+        ), """Output shape doesn't match expected shape
+            (1, OutputDim, OutputDim, NumChannels)."""
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index 0d62862c222b44d2e507a90a80bfcd4fa405d3fe..238829e03353d79fab7c51e7d1b9dca6e2a96a11 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -44,6 +44,7 @@ from finn.custom_op.fpgadataflow.streamingdatawidthconverter_batch import (
     StreamingDataWidthConverter_Batch,
 )
 from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
+from finn.custom_op.fpgadataflow.sameresize_batch import SameResize_Batch
 from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
@@ -64,6 +65,7 @@ custom_op["MaxPoolNHWC"] = MaxPoolNHWC
 custom_op["StreamingDataWidthConverter_Batch"] = StreamingDataWidthConverter_Batch
 custom_op["StreamingFIFO"] = StreamingFIFO
 custom_op["GlobalAccPool_Batch"] = GlobalAccPool_Batch
+custom_op["SameResize_Batch"] = SameResize_Batch
 custom_op["Thresholding_Batch"] = Thresholding_Batch
 custom_op["AddStreams_Batch"] = AddStreams_Batch
 custom_op["LabelSelect_Batch"] = LabelSelect_Batch
diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
index 049ede5064d252bd6391184c4227e5367a8c1e2b..18d3db18da089a5dda4dbb6d97180dd4a20613b5 100644
--- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py
+++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
@@ -107,6 +107,13 @@ class MakePYNQDriver(Transformation):
         driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", mss(o_tensor_shape_folded))
         driver = driver.replace("$OUTPUT_SHAPE_PACKED$", mss(o_tensor_shape_packed))
 
+        # clock settings for driver
+        clk_ns = float(model.get_metadata_prop("clk_ns"))
+        fclk_mhz = 1 / (clk_ns * 0.001)
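+        # clk_ns is the clock period in ns, so fclk [MHz] = 1000 / clk_ns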
+        # TODO change according to PYNQ board?
+        driver = driver.replace("$CLK_NAME$", "fclk0_mhz")
+        driver = driver.replace("$CLOCK_FREQ_MHZ$", str(fclk_mhz))
+
         with open(driver_py, "w") as f:
             f.write(driver)
         # copy all the dependencies into the driver folder
diff --git a/src/finn/transformation/fpgadataflow/templates.py b/src/finn/transformation/fpgadataflow/templates.py
index 55ecb57decd2ac4fa08331b5ebbcb7fd2f0cd5c6..ab9fd03251819aee72f74cc0c1fa17b99b1e05a4 100644
--- a/src/finn/transformation/fpgadataflow/templates.py
+++ b/src/finn/transformation/fpgadataflow/templates.py
@@ -91,7 +91,7 @@ cd %s
 
 pynq_driver_template = """
 import argparse
-
+import os
 from pynq import Overlay
 import numpy as np
 from pynq import allocate
@@ -101,6 +101,7 @@ from finn.util.data_packing import (
     packed_bytearray_to_finnpy
 )
 from finn.core.datatype import DataType
+from pynq.ps import Clocks
 
 class FINNAccelDriver():
     def __init__(self, N, bitfile):
@@ -118,8 +119,12 @@ class FINNAccelDriver():
         self.oshape_folded = $OUTPUT_SHAPE_FOLDED$
         self.ishape_packed = $INPUT_SHAPE_PACKED$   # datatype np.uint8
         self.oshape_packed = $OUTPUT_SHAPE_PACKED$  # datatype np.uint8
+        # clock frequency
+        self.fclk_mhz = $CLOCK_FREQ_MHZ$
         # load bitfile and set up accelerator
         self.ol = Overlay(bitfile)
+        # set the clock frequency as specified by user during transformations
+        Clocks.$CLK_NAME$ = self.fclk_mhz
         self.dma = self.ol.axi_dma_0
         self.ctrl_regs = self.ol.resize_accel_0
         # neuron folding factor of output = iterations per sample
@@ -202,6 +207,12 @@ if __name__ == "__main__":
     # for the remote execution the data from the input npy file has to be loaded,
     # packed and copied to the PYNQ buffer
     if exec_mode == "execute":
+        # remove old output file to prevent reusing old output
+        # in case execution fails
+        try:
+            os.remove(outputfile)
+        except FileNotFoundError:
+            pass
         # load desired input .npy file
         ibuf_normal = np.load(inputfile)
         ibuf_folded = finnDriver.fold_input(ibuf_normal)
@@ -212,10 +223,15 @@ if __name__ == "__main__":
 
     # for the throughput test the runtime of the network has to be measured
     if exec_mode == "throughput_test":
-        # measure runtime of network
-        start = time.time()
+        # remove old metrics file
+        try:
+            os.remove("nw_metrics.txt")
+        except FileNotFoundError:
+            pass
         # dictionary for results of throughput test
         res={}
+        # measure runtime of network
+        start = time.time()
 
     # execute accelerator
     finnDriver.execute()
@@ -228,6 +244,8 @@ if __name__ == "__main__":
         res["throughput[images/s]"] = N / runtime
         res["DRAM_in_bandwidth[Mb/s]"] = np.prod(finnDriver.ishape_packed)*0.000001 / runtime
         res["DRAM_out_bandwidth[Mb/s]"] = np.prod(finnDriver.oshape_packed)*0.000001 / runtime
+        res["fclk[mhz]"] = Clocks.fclk0_mhz
+        res["N"] = N
         file = open("nw_metrics.txt", "w")
         file.write(str(res))
         file.close()
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 0d709297a9132b15b51435b7ab4b51ce55c7e9f3..dbcf97361017144174f9fbfca35a84361b5abd26 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -46,7 +46,11 @@ class AbsorbAddIntoMultiThreshold(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Add":
+            if (
+                n.op_type == "Add"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
                 if consumer is not None and consumer.op_type == "MultiThreshold":
                     add_weight_name = n.input[1]
@@ -83,7 +87,11 @@ class AbsorbMulIntoMultiThreshold(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Mul":
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 mul_weight_name = n.input[1]
                 A = model.get_initializer(mul_weight_name)
                 assert A is not None, "Initializer for mul weights is not set."
diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py
index aa059747b602bc6b659bc8b53b1f18988bba1ef0..67824ad4f633983b93e3178d03118927a1ddd85b 100644
--- a/src/finn/transformation/streamline/collapse_repeated.py
+++ b/src/finn/transformation/streamline/collapse_repeated.py
@@ -48,9 +48,17 @@ class CollapseRepeatedOp(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == self.op_name:
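+            # only collapse ops on a linear path: a fork/join in between
+            # would make the rewrite visible to other branches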
+            if (
+                n.op_type == self.op_name
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
                 consumer = model.find_consumer(n.output[0])
-                if consumer is not None and consumer.op_type == self.op_name:
+                if (
+                    consumer is not None
+                    and consumer.op_type == self.op_name
+                    and not model.is_join_node(consumer)
+                ):
                     op0_param_name = n.input[1]
                     op1_param_name = consumer.input[1]
                     op0_param = model.get_initializer(op0_param_name)
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index b91ffdb3f731d27d9a6ba68b090f3881e6d7293a..0b6259a61d3eb67b7b38d4c6939019ce2893a875 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -244,7 +244,12 @@ class MoveScalarAddPastConv(Transformation):
                     start_name = n.input[0]
                     end_name = consumer.output[0]
                     conv_out_shape = model.get_tensor_shape(end_name)
-                    if all(x == 1 for x in A.shape):
+
+                    pads = list(get_by_name(consumer.attribute, "pads").ints)
+                    using_padding = sum(pads) != 0
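+                    # moving the Add past a padded Conv is unsafe, since the
+                    # added constant would be missing from the zero-padded
+                    # border, so only pad-free convolutions are handled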
+                    if all(x == 1 for x in A.shape) and not using_padding:
                         # create a tensor filled with the add constant, in
                         # the shape expected by the convolution
                         conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
@@ -256,7 +261,8 @@ class MoveScalarAddPastConv(Transformation):
                         execute_node(conv_node, exec_ctx, model.graph)
                         # retrieve the conv output
                         Anew = exec_ctx[end_name]
-                        # strip out repetition
+
+                        # strip out repetition if no padding
                         Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)
                         # update the add weight
                         model.set_initializer(add_weight_name, Anew)
@@ -274,6 +280,7 @@ class MoveScalarAddPastConv(Transformation):
                         graph.node.remove(add_node)
                         graph.node.insert(node_ind, add_node)
                         graph_modified = True
+
         model = model.transform(InferShapes())
         return (model, graph_modified)
 
@@ -437,3 +444,90 @@ class MakeMaxPoolNHWC(Transformation):
                         graph.node.insert(node_ind - 1, consumer)
                         graph_modified = True
         return (model, graph_modified)
+
+
+class MoveOpPastFork(Transformation):
+    """Move node operations past graph forks. Used when a node before a fork
+     can be merged with nodes in the branches
+    """
+
+    def __init__(self, op_name_list):
+        super().__init__()
+        self.ops_to_move = op_name_list
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        # snapshot the node list, since the loop below mutates graph.node
+        nodes = list(graph.node)
+        node_ind = 0
+        for n in nodes:
+            node_ind += 1
+            if (
+                n.op_type in self.ops_to_move
+                and model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+
+                # Restrict this transform to operations with constant parameters
+                # Assuming parameters is in input 1
+                op_init_param = model.get_initializer(n.input[1])
+                if op_init_param is None:
+                    continue
+
+                # if all branches are empty and end at the same consumer,
+                # there is nothing to distribute, so skip this node
+                consumers = model.find_consumers(n.output[0])
+                unique_consumer = True
+                for consum_node in consumers[1:]:
+                    if consumers[0] != consum_node:
+                        unique_consumer = False
+                        break
+
+                if unique_consumer:
+                    continue
+
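+                # the first branch keeps the original node; every further
+                # branch gets its own copy with a fresh output tensor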
+                for consumer_node in consumers[1:]:
+                    # create new node
+                    new_param_name = model.make_new_valueinfo_name()
+                    new_output_tensor_name = model.make_new_valueinfo_name()
+                    new_node = oh.make_node(
+                        n.op_type,
+                        [n.input[0], new_param_name],
+                        [new_output_tensor_name],
+                    )
+                    graph.node.insert(node_ind, new_node)
+                    node_ind += 1
+                    model.set_initializer(new_param_name, op_init_param)
+
+                    # change consumer input tensor
+                    graph.node.remove(consumer_node)
+                    for idx, consumer_input in enumerate(consumer_node.input):
+                        if consumer_input == n.output[0]:
+                            consumer_node.input[idx] = new_output_tensor_name
+                            break
+                    else:
+                        raise Exception(
+                            "Consumer should have the current node output as input"
+                        )
+
+                    graph.node.insert(node_ind, consumer_node)
+
+                graph_modified = True
+
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class MoveAddPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add"])
+
+
+class MoveMulPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Mul"])
+
+
+class MoveLinearPastFork(MoveOpPastFork):
+    def __init__(self):
+        super().__init__(["Add", "Mul"])
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index bc413bf665e96be1d58a5de13b0744fd6a80f855..3880bb9591e27af5fe9d063dba2485d304e4db54 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -43,6 +43,13 @@ pynq_part_map["Pynq-Z1"] = "xc7z020clg400-1"
 pynq_part_map["Pynq-Z2"] = "xc7z020clg400-1"
 pynq_part_map["ZCU104"] = "xczu7ev-ffvc1156-2-e"
 
+# native AXI HP port width (in bits) for PYNQ boards
+pynq_native_port_width = dict()
+pynq_native_port_width["Pynq-Z1"] = 64
+pynq_native_port_width["Pynq-Z2"] = 64
+pynq_native_port_width["Ultra96"] = 128
+pynq_native_port_width["ZCU104"] = 128
+
 
 def get_rtlsim_trace_depth():
     """Return the trace depth for rtlsim via PyVerilator. Controllable
diff --git a/tests/end2end/test_end2end_cnv_w1a1.py b/tests/end2end/test_end2end_cnv_w1a1.py
index e6d1fc4efd61c01654ee88638698215d23a82eb3..c3359dcc82650bf0e9e8a5bc5276f5ca770ee96c 100644
--- a/tests/end2end/test_end2end_cnv_w1a1.py
+++ b/tests/end2end/test_end2end_cnv_w1a1.py
@@ -76,7 +76,7 @@ from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
 build_dir = "/tmp/" + os.environ["FINN_INST_NAME"]
 test_pynq_board = os.getenv("PYNQ_BOARD", default="Pynq-Z1")
 test_fpga_part = pynq_part_map[test_pynq_board]
-target_clk_ns = 5
+target_clk_ns = 10
 mem_mode = "decoupled"
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a1_throughput_test.py b/tests/end2end/test_end2end_tfc_w1a1.py
similarity index 98%
rename from tests/end2end/test_end2end_tfc_w1a1_throughput_test.py
rename to tests/end2end/test_end2end_tfc_w1a1.py
index 1ba149687bb80a0f977115bd380a09f70eef23f1..15c1c41b006c6f87d79a0e7eb6a4458838de5fd2 100644
--- a/tests/end2end/test_end2end_tfc_w1a1_throughput_test.py
+++ b/tests/end2end/test_end2end_tfc_w1a1.py
@@ -41,7 +41,6 @@ import onnx.numpy_helper as nph
 import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
 import finn.transformation.streamline.absorb as absorb
 from finn.core.onnx_exec import execute_onnx
-from finn.core.throughput_test import throughput_test
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
 from finn.transformation.fold_constants import FoldConstants
@@ -332,9 +331,6 @@ def test_end2end_tfc_w1a1_run_on_pynq():
         ret = execute_onnx(parent_model, {iname: x}, True)
         y = ret[oname]
         assert np.isclose(y, y_golden).all()
-        child_model = load_test_checkpoint_or_skip(sdp_node.get_nodeattr("model"))
-        res = throughput_test(child_model)
-        assert res is not None
 
     except KeyError:
         pytest.skip("PYNQ board IP address not specified")
diff --git a/tests/fpgadataflow/test_fpgadataflow_sameresize.py b/tests/fpgadataflow/test_fpgadataflow_sameresize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea6130c3891443595b038460233ebb85799ac461
--- /dev/null
+++ b/tests/fpgadataflow/test_fpgadataflow_sameresize.py
@@ -0,0 +1,125 @@
+import pytest
+import os
+import numpy as np
+
+from onnx import TensorProto, helper
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.util.basic import gen_finn_dt_tensor
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+
+from finn.util.basic import pynq_part_map
+
+test_pynq_board = os.getenv("PYNQ_BOARD", default="Pynq-Z1")
+test_fpga_part = pynq_part_map[test_pynq_board]
+target_clk_ns = 10
+
+
+def make_single_sameresize_modelwrapper(
+    idim, odim, kdim, stride, num_ch, idt, pad_style
+):
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, idim, idim, num_ch]
+    )
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, odim, odim, num_ch]
+    )
+
+    SameResize_node = helper.make_node(
+        "SameResize_Batch",
+        ["inp"],
+        ["outp"],
+        domain="finn",
+        backend="fpgadataflow",
+        ImgDim=idim,
+        KernelDim=kdim,
+        Stride=stride,
+        NumChannels=num_ch,
+        inputDataType=str(idt.name),
+        PaddingStyle=pad_style,
+    )
+
+    graph = helper.make_graph(
+        nodes=[SameResize_node], name="sameresize_graph", inputs=[inp], outputs=[outp]
+    )
+
+    model = helper.make_model(graph, producer_name="sameresize-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", idt)
+
+    return model
+
+
+# image dimension
+@pytest.mark.parametrize("idim", [8, 16])
+# kernel dimension
+@pytest.mark.parametrize("kdim", [2, 3])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# number of channels
+@pytest.mark.parametrize("num_ch", [1, 2])
+# FINN input datatype
+@pytest.mark.parametrize("idt", [DataType.INT2, DataType.INT4])
+# execution mode
+@pytest.mark.parametrize("mode", ["cppsim", "rtlsim"])
+@pytest.mark.slow
+@pytest.mark.vivado
+def test_fpgadataflow_sameresize(idim, kdim, stride, num_ch, idt, mode):
+    pad_style = 2
+    assert idim % stride == 0, "Stride must divide input dimension."
+    # number of "same" windows over the input data
+    same_windows = idim // stride
+    odim = kdim + stride * (same_windows - 1)
+
+    # generate input data
+    x = gen_finn_dt_tensor(idt, [1, idim, idim, num_ch])
+    input_dict = {"inp": x}
+
+    model = make_single_sameresize_modelwrapper(
+        idim, odim, kdim, stride, num_ch, idt, pad_style
+    )
+    model = model.transform(InferShapes())
+    model = model.transform(SetExecMode(mode))
+    model = model.transform(GiveUniqueNodeNames())
+    if mode == "cppsim":
+        model = model.transform(PrepareCppSim())
+        model = model.transform(CompileCppSim())
+    elif mode == "rtlsim":
+        model = model.transform(PrepareIP(test_fpga_part, target_clk_ns))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(PrepareRTLSim())
+    y_produced = oxe.execute_onnx(model, input_dict)["outp"]
+    expected_oshape = (1, odim, odim, num_ch)
+    assert y_produced.shape == expected_oshape
+
+    # calculate reference
+    # calculate correct padding according to parameters
+    pad = odim - idim
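+    # e.g. idim=8, kdim=3, stride=1: odim=10, pad=2, split 1/1 per side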
+    if pad_style == 2:
+        if pad % 2 == 0:
+            pad_up = pad // 2
+            pad_left = pad // 2
+        else:
+            pad_up = pad // 2 + 1
+            pad_left = pad // 2 + 1
+    else:
+        pad_up = pad // 2
+        pad_left = pad // 2
+    pad_down = pad - pad_up
+    pad_right = pad - pad_left
+
+    y_expected = np.pad(
+        x, ((0, 0), (pad_up, pad_down), (pad_left, pad_right), (0, 0)), "constant"
+    )
+
+    assert (y_produced == y_expected).all()
diff --git a/tests/pynq/test_pynq_performance_end2end.py b/tests/pynq/test_pynq_performance_end2end.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a93a190061e0142637be19bb2ea841d192745a
--- /dev/null
+++ b/tests/pynq/test_pynq_performance_end2end.py
@@ -0,0 +1,65 @@
+import os
+
+import pytest
+import numpy as np
+from scipy.stats import linregress
+import warnings
+from finn.util.test import load_test_checkpoint_or_skip
+from finn.core.throughput_test import throughput_test
+
+build_dir = "/tmp/" + os.environ["FINN_INST_NAME"]
+
+
+@pytest.mark.parametrize("end2end_example", ["tfc_w1a1", "cnv_w1a1"])
+@pytest.mark.slow
+def test_pynq_performance_end2end(end2end_example):
+    model = load_test_checkpoint_or_skip(
+        build_dir + "/end2end_%s_pynq_deploy.onnx" % end2end_example
+    )
+    try:
+        ip = os.environ["PYNQ_IP"]  # NOQA
+        board = os.environ["PYNQ_BOARD"]  # NOQA
+        if ip == "" or board == "":
+            pytest.skip("PYNQ board or IP address not specified")
+        ret = dict()
+        # try a range of batch sizes, some may fail due to insufficient DMA
+        # buffers
+        bsize_range_in = [2 ** i for i in range(16)]
+        bsize_range = []
+        for bsize in bsize_range_in:
+            res = throughput_test(model, bsize)
+            if res is not None:
+                ret[bsize] = res
+                bsize_range.append(bsize)
+            else:
+                # assume we reached largest possible N
+                break
+
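+        # runtime should scale linearly with batch size: the regression
+        # slope gives the time per sample, the intercept the call overhead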
+        y = [ret[key]["runtime[ms]"] for key in bsize_range]
+        lrret = linregress(bsize_range, y)
+        ret_str = ""
+        ret_str += "\n" + "%s Throughput Test Results" % end2end_example
+        ret_str += "\n" + "-----------------------------"
+        ret_str += "\n" + "From linear regression:"
+        ret_str += "\n" + "Invocation overhead: %f ms" % lrret.intercept
+        ret_str += "\n" + "Time per sample: %f ms" % lrret.slope
+        ret_str += "\n" + "Raw data:"
+
+        ret_str += "\n" + "{:<8} {:<16} {:<16} {:<16} {:<16} {:<16}".format(
+            "N", "runtime[ms]", "fclk[mhz]", "fps", "DRAM rd[Mb/s]", "DRAM wr[Mb/s]"
+        )
+        for k in bsize_range:
+            v = ret[k]
+            ret_str += "\n" + "{:<8} {:<16} {:<16} {:<16} {:<16} {:<16}".format(
+                k,
+                np.round(v["runtime[ms]"], 4),
+                v["fclk[mhz]"],
+                np.round(v["throughput[images/s]"], 2),
+                np.round(v["DRAM_in_bandwidth[Mb/s]"], 2),
+                np.round(v["DRAM_out_bandwidth[Mb/s]"], 2),
+            )
+        ret_str += "\n" + "-----------------------------"
+        warnings.warn(ret_str)
+
+    except KeyError:
+        pytest.skip("PYNQ board or IP address not specified")
diff --git a/tests/pynq/test_pynq_performance_fifo.py b/tests/pynq/test_pynq_performance_fifo.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4542473c4b58d3baa62f4123fd0f2f76954d95
--- /dev/null
+++ b/tests/pynq/test_pynq_performance_fifo.py
@@ -0,0 +1,128 @@
+import os
+
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.insert_tlastmarker import InsertTLastMarker
+from finn.transformation.fpgadataflow.make_deployment import DeployToPYNQ
+from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver
+from finn.transformation.fpgadataflow.make_pynq_proj import MakePYNQProject
+from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
+import finn.transformation.fpgadataflow.replace_verilog_relpaths as rvp
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.util.basic import pynq_part_map, pynq_native_port_width
+from finn.core.throughput_test import throughput_test
+from scipy.stats import linregress
+import warnings
+
+
+def make_single_fifo_modelwrapper(Shape, Depth, fld_shape, finn_dtype):
+
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, Shape)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, Shape)
+
+    FIFO_node = helper.make_node(
+        "StreamingFIFO",
+        ["inp"],
+        ["outp"],
+        domain="finn",
+        backend="fpgadataflow",
+        depth=Depth,
+        folded_shape=fld_shape,
+        dataType=str(finn_dtype.name),
+    )
+
+    graph = helper.make_graph(
+        nodes=[FIFO_node], name="fifo_graph", inputs=[inp], outputs=[outp]
+    )
+
+    model = helper.make_model(graph, producer_name="fifo-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", finn_dtype)
+    model.set_tensor_datatype("outp", finn_dtype)
+
+    return model
+
+
+@pytest.mark.vivado
+@pytest.mark.slow
+def test_pynq_performance_fifo():
+    try:
+        ip = os.environ["PYNQ_IP"]  # NOQA
+        board = os.environ["PYNQ_BOARD"]  # NOQA
+        if ip == "" or board == "":
+            pytest.skip("PYNQ board or IP address not specified")
+        fifo_width = pynq_native_port_width[board]
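+        # size the FIFO stream to the board's native AXI HP port width so
+        # the DMA can be driven at full width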
+        shape = (1, fifo_width)
+        folded_shape = (1, 1, fifo_width)
+        depth = 16
+        clk_ns = 10
+        dtype = DataType.BIPOLAR
+        fpga_part = pynq_part_map[board]
+        username = os.getenv("PYNQ_USERNAME", "xilinx")
+        password = os.getenv("PYNQ_PASSWORD", "xilinx")
+        port = os.getenv("PYNQ_PORT", 22)
+        target_dir = os.getenv("PYNQ_TARGET_DIR", "/home/xilinx/finn")
+
+        model = make_single_fifo_modelwrapper(shape, depth, folded_shape, dtype)
+        model = model.transform(InsertTLastMarker())
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(PrepareIP(fpga_part, clk_ns))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(rvp.ReplaceVerilogRelPaths())
+        model = model.transform(CreateStitchedIP(fpga_part, clk_ns))
+        model = model.transform(MakePYNQProject(board))
+        model = model.transform(SynthPYNQProject())
+        model = model.transform(MakePYNQDriver())
+        model = model.transform(DeployToPYNQ(ip, port, username, password, target_dir))
+
+        ret = dict()
+        # try a range of batch sizes, some may fail due to insufficient DMA
+        # buffers
+        bsize_range_in = [2 ** i for i in range(20)]
+        bsize_range = []
+        for bsize in bsize_range_in:
+            res = throughput_test(model, bsize)
+            if res is not None:
+                ret[bsize] = res
+                bsize_range.append(bsize)
+            else:
+                # assume we reached largest possible N
+                break
+
+        y = [ret[key]["runtime[ms]"] for key in bsize_range]
+        lrret = linregress(bsize_range, y)
+        ret_str = ""
+        ret_str += "\n" + "FIFO Throughput Test Results"
+        ret_str += "\n" + "-----------------------------"
+        ret_str += "\n" + "From linear regression:"
+        ret_str += "\n" + "Invocation overhead: %f ms" % lrret.intercept
+        ret_str += "\n" + "Time per sample: %f ms" % lrret.slope
+        ret_str += "\n" + "Raw data:"
+
+        ret_str += "\n" + "{:<8} {:<16} {:<16} {:<16} {:<16} {:<16}".format(
+            "N", "runtime[ms]", "fclk[mhz]", "fps", "DRAM rd[Mb/s]", "DRAM wr[Mb/s]"
+        )
+        for k in bsize_range:
+            v = ret[k]
+            ret_str += "\n" + "{:<8} {:<16} {:<16} {:<16} {:<16} {:<16}".format(
+                k,
+                np.round(v["runtime[ms]"], 4),
+                v["fclk[mhz]"],
+                np.round(v["throughput[images/s]"], 2),
+                np.round(v["DRAM_in_bandwidth[Mb/s]"], 2),
+                np.round(v["DRAM_out_bandwidth[Mb/s]"], 2),
+            )
+        ret_str += "\n" + "-----------------------------"
+        warnings.warn(ret_str)
+
+    except KeyError:
+        pytest.skip("PYNQ board or IP address not specified")
diff --git a/tests/transformation/test_collapse_repeated_op.py b/tests/transformation/test_collapse_repeated_op.py
index 01d932ece0be4b0beb7ad6094284ec3efb1e525e..b74d868f9b921c35ff9f596c811583f45f761374 100644
--- a/tests/transformation/test_collapse_repeated_op.py
+++ b/tests/transformation/test_collapse_repeated_op.py
@@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
+import pytest
 
 
 def test_collapse_repeated_op():
@@ -67,3 +68,60 @@ def test_collapse_repeated_op():
     new_model = new_model.transform(CollapseRepeatedMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 2
+    assert new_model.graph.node[0].op_type == "Add"
+    assert new_model.graph.node[1].op_type == "Mul"
+
+
+@pytest.mark.parametrize(
+    "test_args", [("Add", CollapseRepeatedAdd()), ("Mul", CollapseRepeatedMul())],
+)
+def test_collapse_repeated_only_if_linear(test_args):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    input_shape = [4, 4]
+    output_shape = input_shape
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, [1])]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p2"], ["t1"]),
+                oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
+                oh.make_node(scalar_op, ["t2", "p3"], ["t3"]),
+                oh.make_node(scalar_op, ["t2", "p4"], ["t4"]),
+                oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
+                oh.make_node(scalar_op, ["t5", "p5"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p3", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p4", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p5", *np.random.rand(1).astype(np.float32))
+
+    # Transform
+    new_model = model.transform(transf_fxn)
+
+    # Test
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert len(new_model.graph.node) == 5
diff --git a/tests/transformation/test_move_past_fork.py b/tests/transformation/test_move_past_fork.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3d37bd60c9e2580ca4499daafa8693f39fec810
--- /dev/null
+++ b/tests/transformation/test_move_past_fork.py
@@ -0,0 +1,79 @@
+from onnx import TensorProto, helper
+import numpy as np
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline.reorder import MoveLinearPastFork
+from finn.transformation.infer_shapes import InferShapes
+
+import pytest
+
+
+@pytest.mark.parametrize("ch", [64, 1])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_move_past_fork(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_shape = (1, ch)
+    else:
+        input_shape = (1, ch, ifmdim, ifmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    num_of_params = 8
+    value_info = []
+    for i in range(num_of_params):
+        value_info += [
+            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
+        ]
+
+    add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
+    mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
+    add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
+    mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                # fork1
+                add_1_to_move,
+                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
+                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
+                helper.make_node("Add", ["t2", "t3"], ["t4"]),
+                helper.make_node("Add", ["t4", "p3"], ["t5"]),
+                # fork2
+                mul_1_to_move,
+                add_2_to_move,
+                helper.make_node("Add", ["fork2", "p6"], ["t7"]),
+                helper.make_node("Add", ["t6", "t7"], ["t8"]),
+                # fork3: both branches end at the same Add, so the
+                # transform must leave this fork untouched
+                mul_1_not_to_move,
+                helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    for i in range(num_of_params):
+        model.set_initializer(
+            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
+        )
+
+    # Transform
+    new_model = model.transform(MoveLinearPastFork())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert not new_model.is_fork_node(add_1_to_move)
+    assert not new_model.is_fork_node(mul_1_to_move)
+    assert not new_model.is_fork_node(add_2_to_move)
+    assert new_model.is_fork_node(mul_1_not_to_move)
+    assert len(new_model.graph.node) == 14
diff --git a/tests/transformation/test_move_scalar_past_conv.py b/tests/transformation/test_move_scalar_past_conv.py
index 9992d17b96ab5f419f3ac495f126ddfa736349a2..0f50642d2b9d1583030630cb4927c2b86667e71a 100644
--- a/tests/transformation/test_move_scalar_past_conv.py
+++ b/tests/transformation/test_move_scalar_past_conv.py
@@ -12,6 +12,85 @@ from finn.transformation.streamline import (
 )
 
 
+@pytest.mark.parametrize("padding", [False, True])
+@pytest.mark.parametrize(
+    "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
+)
+def test_move_scalar_past_conv(test_args, padding):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    in_feature_dim = 7
+    in_chn = 3
+
+    stages = 2
+    kernel_size = 3
+
+    out_feature_dim = (
+        in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
+    )
+
+    input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
+    output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
+
+    conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
+
+    conv_config = {}
+    conv_config["dilations"] = [1, 1]
+    conv_config["group"] = 1
+    conv_config["kernel_shape"] = [kernel_size, kernel_size]
+    if padding:
+        conv_config["pads"] = [1, 1, 1, 1]
+    else:
+        conv_config["pads"] = [0, 0, 0, 0]
+    conv_config["strides"] = [1, 1]
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
+                oh.make_node("Conv", ["t1", "p2"], ["t2"], **conv_config),
+                oh.make_node("Conv", ["t2", "p3"], ["top_out"], **conv_config),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
+    new_model = model.transform(transf_fxn)
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    assert ox.compare_execution(model, new_model, inp_dict)
+    if scalar_op == "Add":
+        if padding:
+            assert new_model.graph.node[0].op_type == scalar_op
+            assert new_model.graph.node[1].op_type == "Conv"
+            assert new_model.graph.node[2].op_type == "Conv"
+        else:
+            assert new_model.graph.node[0].op_type == "Conv"
+            assert new_model.graph.node[1].op_type == scalar_op
+            assert new_model.graph.node[2].op_type == "Conv"
+    else:
+        assert new_model.graph.node[0].op_type == "Conv"
+        assert new_model.graph.node[1].op_type == "Conv"
+        assert new_model.graph.node[2].op_type == scalar_op
+
+
 @pytest.mark.parametrize(
     "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
 )