diff --git a/AUTHORS.rst b/AUTHORS.rst
index e231e61d38991e11e2e43a7c9a3a78c50c878244..eb1e06e54b7eb6deedd3e7f8392bb3aa257e7dc6 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -6,3 +6,5 @@ Contributors
 * Jakoba Petri-Koenig (@auphelia)
 * Andrea Rigoni (@AndreaRigoni)
 * Hendrik Borras (@HenniOVP)
+* Lucian Petrica (@quetric)
+* Tobias Alonso (@Tobi-Alonso)
diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index c77fd81c0bfaa77b458368807410b8bfec17abb7..17a55e519ed0440f68e295aecaab179e6adf632f 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -40,6 +40,7 @@ from finn.util.basic import (
 from finn.util.fpgadataflow import (
     IPGenBuilder,
     pyverilate_get_liveness_threshold_cycles,
+    rtlsim_multi_io,
 )
 from . import templates
 
@@ -318,14 +319,24 @@ Found no codegen dir for this node, did you run the prepare_cppsim transformatio
             )
 
     def npy_to_dynamic_output(self, context):
-        """Reads the output from a .npy file and saves it at the right place in
-        the context dictionary."""
-        # TODO support multi-output nodes as needed
+        """Reads the output from an output.npy file generated from cppsim and
+        places its content into the context dictionary."""
         node = self.onnx_node
         code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
         output = np.load("{}/output.npy".format(code_gen_dir))
         context[node.output[0]] = output
 
+    def npy_to_dynamic_outputs(self, context, npy_list):
+        """Reads the output from .npy files generated from cppsim and places
+        their content into the context dictionary.
+        npy_list is a list specifying which files to read, and its order must
+        match the order of node outputs."""
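+        # e.g. npy_to_dynamic_outputs(context, ["output0.npy", "output1.npy"])
+        # maps output0.npy to node.output[0] and output1.npy to node.output[1]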
+        node = self.onnx_node
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        for i, npy_file in enumerate(npy_list):
+            output = np.load("{}/{}".format(code_gen_dir, npy_file))
+            context[node.output[i]] = output
+
     def exec_precompiled_singlenode_model(self):
         """Executes precompiled executable."""
         executable_path = self.get_nodeattr("executable_path")
@@ -421,6 +432,16 @@ compilation transformations?
             sim.stop_vcd_trace()
         return outputs
 
+    def rtlsim_multi_io(self, sim, io_dict):
+        "Run rtlsim for this node, supports multiple i/o streams."
+
+        trace_file = self.get_nodeattr("rtlsim_trace")
+        if trace_file == "default":
+            trace_file = self.onnx_node.name + ".vcd"
+        num_out_values = self.get_number_output_values()
+        total_cycle_count = rtlsim_multi_io(sim, io_dict, num_out_values, trace_file)
+        self.set_nodeattr("sim_cycles", total_cycle_count)
+
     def execute_node(self, context, graph):
         """Executes single node using cppsim or rtlsim."""
         mode = self.get_nodeattr("exec_mode")
diff --git a/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..54051af5e0387081a23e1f8fa77ec9e363098830
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/duplicatestreams_batch.py
@@ -0,0 +1,361 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+
+import numpy as np
+
+from finn.core.datatype import DataType
+from finn.custom_op.fpgadataflow import HLSCustomOp
+from onnx import helper
+from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
+
+
+class DuplicateStreams_Batch(HLSCustomOp):
+    """Class that corresponds to finn-hlslib function of the same name."""
+
+    def __init__(self, onnx_node):
+        super().__init__(onnx_node)
+
+    def get_nodeattr_types(self):
+        my_attrs = {
+            "NumChannels": ("i", True, 0),
+            "PE": ("i", True, 0),
+            # FINN DataTypes for input
+            "inputDataType": ("s", True, ""),
+            # number of input vectors, examples:
+            # [1] is a single vector (like a FC layer with batch=1)
+            # [4] is four vectors (like a FC layer with batch=4)
+            # [1, 4, 4] is four * four vectors (like a conv layer with batch=1)
+            "numInputVectors": ("ints", False, [1]),
+        }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
+
+    def get_normal_input_shape(self):
+        ch = self.get_nodeattr("NumChannels")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        ishape = tuple(vecs + [ch])
+        return ishape
+
+    def get_folded_input_shape(self):
+        ch = self.get_nodeattr("NumChannels")
+        pe = self.get_nodeattr("PE")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        assert ch % pe == 0, "PE must divide NumChannels"
+        folds = int(ch / pe)
+        folded_ishape = tuple(vecs + [folds, pe])
+        return folded_ishape
+
+    def get_normal_output_shape(self):
+        return self.get_normal_input_shape()
+
+    def get_folded_output_shape(self):
+        return self.get_folded_input_shape()
+
+    def make_shape_compatible_op(self, model):
+        exp_ishape = self.get_normal_input_shape()
+        oshape = self.get_normal_output_shape()
+        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
+        assert ishape == exp_ishape, "Unexpected input shape."
+        # implement tensor with correct shape: register a doubled random
+        # constant and Split it along axis 0 so that shape inference yields
+        # two outputs of the expected shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        split_input = np.concatenate((values, values), axis=0)
+        split_in_name = model.make_new_valueinfo_name()
+        model.set_initializer(split_in_name, split_input)
+        return helper.make_node(
+            "Split",
+            inputs=[split_in_name],
+            outputs=[self.onnx_node.output[0], self.onnx_node.output[1]],
+            axis=0,
+        )
+
+    def infer_node_datatype(self, model):
+        odt = self.get_output_datatype()
+        model.set_tensor_datatype(self.onnx_node.output[0], odt)
+        model.set_tensor_datatype(self.onnx_node.output[1], odt)
+
+    def verify_node(self):
+        info_messages = []
+        # verify that "domain" is set to "finn"
+        domain_value = self.onnx_node.domain
+        if domain_value == "finn":
+            info_messages.append("Attribute domain is set correctly")
+        else:
+            info_messages.append('Attribute domain should be set to "finn"')
+
+        # verify that "backend" is set to "fpgadataflow"
+        backend_value = self.get_nodeattr("backend")
+        if backend_value == "fpgadataflow":
+            info_messages.append("Attribute backend is set correctly")
+        else:
+            info_messages.append('Attribute backend should be set to "fpgadataflow"')
+
+        # verify that all necessary attributes exist
+        try:
+            self.get_nodeattr("code_gen_dir_cppsim")
+            self.get_nodeattr("executable_path")
+            self.get_nodeattr("NumChannels")
+            self.get_nodeattr("PE")
+            self.get_nodeattr("inputDataType")
+            info_messages.append("All necessary attributes exist")
+        except Exception:
+            info_messages.append(
+                """The required DuplicateStreams_Batch attributes do not exist."""
+            )
+
+        return info_messages
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_instream_width(self):
+        """Returns input stream width."""
+        ibits = self.get_input_datatype().bitwidth()
+        pe = self.get_nodeattr("PE")
+        in_width = pe * ibits
+        return in_width
+
+    def get_outstream_width(self):
+        """Returns output stream width."""
+        obits = self.get_output_datatype().bitwidth()
+        pe = self.get_nodeattr("PE")
+        out_width = pe * obits
+        return out_width
+
+    def get_number_output_values(self):
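+        # every input value appears on both output streams, hence the factor 2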
+        return 2 * np.prod(self.get_folded_output_shape()[1:-1])
+
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+        exp_ishape = self.get_normal_input_shape()
+        exp_oshape = self.get_normal_output_shape()
+        folded_ishape = self.get_folded_input_shape()
+        folded_oshape = self.get_folded_output_shape()
+
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        inp = context[node.input[0]]
+        assert str(inp.dtype) == "float32", "Input datatype is not float32"
+        assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
+        export_idt = self.get_input_datatype()
+        # reshape input into folded form
+        inp = inp.reshape(folded_ishape)
+        # make copy before saving array
+        reshaped_input = inp.copy()
+        np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_outputs(context, ["output0.npy", "output1.npy"])
+            assert (
+                context[node.output[0]].shape == folded_oshape
+            ), "cppsim did not produce expected folded output shape"
+            assert (
+                context[node.output[1]].shape == folded_oshape
+            ), "cppsim did not produce expected folded output shape"
+            context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
+            context[node.output[1]] = context[node.output[1]].reshape(*exp_oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            rtlsim_inp = npy_to_rtlsim_input(
+                "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+            )
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
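+            # stream names here must match the HLS top-level interface names
+            # (in0, out0, out1), see strm_decl and blackboxfunction below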
+            rtlsim_dict = {
+                "inputs": {"in0": rtlsim_inp},
+                "outputs": {"out0": [], "out1": []},
+            }
+            self.rtlsim_multi_io(sim, rtlsim_dict)
+            odt = self.get_output_datatype()
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_shape = self.get_folded_output_shape()
+
+            out_npy_path = "{}/output0.npy".format(code_gen_dir)
+            rtlsim_output_to_npy(
+                rtlsim_dict["outputs"]["out0"],
+                out_npy_path,
+                odt,
+                out_shape,
+                packed_bits,
+                target_bits,
+            )
+            # load and reshape output 0
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[0]] = output
+
+            out_npy_path = "{}/output1.npy".format(code_gen_dir)
+            rtlsim_output_to_npy(
+                rtlsim_dict["outputs"]["out1"],
+                out_npy_path,
+                odt,
+                out_shape,
+                packed_bits,
+                target_bits,
+            )
+            # load and reshape output 1
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[1]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        assert (
+            context[node.output[0]].shape == exp_oshape
+        ), """Output0 shape doesn't match expected shape."""
+        assert (
+            context[node.output[1]].shape == exp_oshape
+        ), """Output1 shape doesn't match expected shape."""
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "streamtools.h"']
+
+    def defines(self, var):
+        self.code_gen_dict["$DEFINES$"] = []
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out0 ("out0");'.format(self.get_outstream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out1 ("out1");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            """DuplicateStreams_Batch<{}, {}> (in0, out0, out1, 1);""".format(
+                self.get_outstream_width(), self.get_number_output_values() // 2,
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_out = "%s/output0.npy" % code_gen_dir
+        npy_out1 = "%s/output1.npy" % code_gen_dir
+        oshape = self.get_folded_output_shape()
+        oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
+
+        self.code_gen_dict["$DATAOUTSTREAM$"] = [
+            'apintstream2npy<%s, %s, %d, %s>(out0, %s, "%s");'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                oshape_cpp_str,
+                npy_out,
+            )
+        ]
+
+        self.code_gen_dict["$DATAOUTSTREAM$"] += [
+            'apintstream2npy<%s, %s, %d, %s>(out1, %s, "%s");'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                oshape_cpp_str,
+                npy_out1,
+            )
+        ]
+
+    def save_as_npy(self):
+        self.code_gen_dict["$SAVEASCNPY$"] = []
+
+    def blackboxfunction(self):
+        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+            """void {}(hls::stream<ap_uint<{}>> &in0,
+                hls::stream<ap_uint<{}>> &out0,
+                hls::stream<ap_uint<{}>> &out1)""".format(
+                self.onnx_node.name,
+                self.get_instream_width(),
+                self.get_outstream_width(),
+                self.get_outstream_width(),
+            )
+        ]
+
+    def pragmas(self):
+        self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out0")
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out1")
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS INTERFACE ap_ctrl_none port=return"
+        )
diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa33c70218fab16f106da45e296f0d59ae4ea606
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -0,0 +1,551 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from math import ceil
+import os
+
+import numpy as np
+
+from onnx import TensorProto, helper
+from finn.core.datatype import DataType
+from finn.custom_op.fpgadataflow import HLSCustomOp
+from finn.util.basic import interleave_matrix_outer_dim_from_partitions
+from finn.util.data_packing import (
+    npy_to_rtlsim_input,
+    numpy_to_hls_code,
+    rtlsim_output_to_npy,
+)
+from . import templates
+
+# ONNX i/o tensor shape assumptions for Thresholding:
+# input 0 is the input tensor, shape (..., NumChannels)
+# input 1 is the threshold tensor, shape (NumChannels, n_thres)
+# output 0 is the output tensor, shape (..., NumChannels) - same as input
+# the ... here can be any shape (representing groups of vectors)
+
+
+class Thresholding_Batch(HLSCustomOp):
+    """Class that corresponds to finn-hls Thresholding_Batch function."""
+
+    def __init__(self, onnx_node):
+        super().__init__(onnx_node)
+        self.decoupled_wrapper = templates.decoupled_wrapper
+
+    def get_nodeattr_types(self):
+        my_attrs = {
+            "PE": ("i", True, 0),
+            "NumChannels": ("i", True, 0),
+            # string defining memory type
+            "ram_style": ("s", False, "distributed"),
+            # FINN DataTypes for inputs, weights, outputs
+            "inputDataType": ("s", True, ""),
+            "outputDataType": ("s", True, ""),
+            # input and output FIFO depths
+            "inFIFODepth": ("i", False, 0),
+            "outFIFODepth": ("i", False, 0),
+            # number of input vectors, examples:
+            # [1] is a single vector (like a FC layer with batch=1)
+            # [4] is four vectors (like a FC layer with batch=4)
+            # [1, 4, 4] is four * four vectors (like a conv layer with batch=1)
+            "numInputVectors": ("ints", False, [1]),
+        }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
+
+    def calc_tmem(self):
+        """Calculates and returns TMEM."""
+        mh = self.get_nodeattr("NumChannels")
+        pe = self.get_nodeattr("PE")
+        return mh // pe
+
+    def make_shape_compatible_op(self, model):
+        oshape = self.get_normal_output_shape()
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # check input datatype against property
+        idt_name = self.get_input_datatype().name
+        exp_idt_name = self.get_nodeattr("inputDataType")
+        assert exp_idt_name == idt_name, "Bad input DataType for Thresholding layer"
+        # set output datatype from property
+        odt = self.get_output_datatype()
+        model.set_tensor_datatype(node.output[0], odt)
+
+    def verify_node(self):
+        info_messages = []
+        # verify that "domain" is set to "finn"
+        domain_value = self.onnx_node.domain
+        if domain_value == "finn":
+            info_messages.append("Attribute domain is set correctly")
+        else:
+            info_messages.append('Attribute domain should be set to "finn"')
+
+        # verify that "backend" is set to "fpgadataflow"
+        backend_value = self.get_nodeattr("backend")
+        if backend_value == "fpgadataflow":
+            info_messages.append("Attribute backend is set correctly")
+        else:
+            info_messages.append('Attribute backend should be set to "fpgadataflow"')
+
+        # verify that all necessary attributes exist
+        # TODO collect automatically from get_nodeattr_types
+        try:
+            self.get_nodeattr("code_gen_dir_cppsim")
+            self.get_nodeattr("executable_path")
+            self.get_nodeattr("NumChannels")
+            self.get_nodeattr("PE")
+            self.get_nodeattr("inputDataType")
+            self.get_nodeattr("outputDataType")
+            info_messages.append("All necessary attributes exist")
+        except Exception:
+            info_messages.append(
+                """The required Threshold_Batch attributes do not exist."""
+            )
+
+        return info_messages
+
+    def bram_estimation(self):
+        """Calculates BRAM cost if resource set to BRAM"""
+        style = self.get_nodeattr("ram_style")
+        P = self.get_nodeattr("PE")
+        idt = self.get_input_datatype()
+        A = idt.bitwidth()
+        tmem = self.calc_tmem()
+
+        if style == "block" and tmem > 1:
+            return int(ceil(A * P / 16)) * int(ceil(tmem / 1024))
+        else:
+            return 0
+
+    def lut_estimation(self):
+        """Calculates LUT cost, taking memory resource type into account """
+        # TODO add in/out FIFO contributions
+        style = self.get_nodeattr("ram_style")
+        P = self.get_nodeattr("PE")
+        idt = self.get_input_datatype()
+        A = idt.bitwidth()
+        tmem = self.calc_tmem()
+        # cost of comparators
+        comparator_cost = A * P
+        # cost of LUTRAM
+        if style == "distributed" and tmem > 1:
+            lutram_cost = P * A * int(ceil(tmem / 64))
+        else:
+            lutram_cost = 0
+        # total cost
+        return comparator_cost + lutram_cost
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output."""
+        return DataType[self.get_nodeattr("outputDataType")]
+
+    def get_instream_width(self):
+        i_bits = self.get_input_datatype().bitwidth()
+        return i_bits * self.get_nodeattr("PE")
+
+    def get_outstream_width(self):
+        o_bits = self.get_output_datatype().bitwidth()
+        return o_bits * self.get_nodeattr("PE")
+
+    def get_folded_input_shape(self):
+        ich = self.get_nodeattr("NumChannels")
+        pe = self.get_nodeattr("PE")
+        fold = ich // pe
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        folded_input_shape = tuple(vecs + [fold, pe])
+        return folded_input_shape
+
+    def get_folded_output_shape(self):
+        # same shape as input
+        return self.get_folded_input_shape()
+
+    def get_normal_input_shape(self):
+        ich = self.get_nodeattr("NumChannels")
+        vecs = list(self.get_nodeattr("numInputVectors"))
+        normal_input_shape = tuple(vecs + [ich])
+        return normal_input_shape
+
+    def get_normal_output_shape(self):
+        # same shape as input
+        return self.get_normal_input_shape()
+
+    def get_number_output_values(self):
+        nf = np.prod(self.get_folded_output_shape()[:-1])
+        return nf
+
+    def get_template_param_values(self):
+        """Returns the template parameter values according to input, output and weight
+        data types."""
+        ret = dict()
+        inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
+        out_hls_str = self.get_output_datatype().get_hls_datatype_str()
+        # fill in TSrcI
+        ret["TSrcI"] = "Slice<%s>" % inp_hls_str
+        # fill in TDstI
+        ret["TDstI"] = "Slice<%s>" % out_hls_str
+
+        return ret
+
+    def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
+        """Convert the original numpy weight matrix orig_weight_matrix into
+        a form suitable for passing to the hlslib call:
+        * ensure MH % PE == 0
+        * for unsigned inputs, ensure thresholds are positive
+        * interleave rows between PEs
+        * reshape into (PE, TMEM, n_thres_steps) and return
+        """
+        mh = self.get_nodeattr("NumChannels")
+        pe = self.get_nodeattr("PE")
+        tmem = mh // pe
+        assert mh % pe == 0, "Requirement NumChannels divisible by PE is violated."
+        assert (
+            orig_thres_matrix.ndim == 2
+        ), """Threshold matrix dimension is
+        not as expected (2)."""
+        n_thres_steps = orig_thres_matrix.shape[1]
+        if not self.get_input_datatype().signed():
+            # ensure all thresholds are nonnegative
+            assert (orig_thres_matrix >= 0).all()
+        # ensure all thresholds are integer
+        assert (orig_thres_matrix.astype(np.int32) == orig_thres_matrix).all()
+        ret = orig_thres_matrix
+        # ensure channels = mh, duplicating if necessary
+        if ret.shape[0] == 1:
+            ret = np.tile(ret, (mh, 1))
+        assert (
+            ret.shape[0] == mh
+        ), "Channels of threshold matrix are not as expected (mh)"
+        # distribute rows between PEs
+        ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
+        assert (
+            ret.shape[0] == pe
+        ), """First dimension after distribution of the
+        rows between PEs is not as expected (pe)"""
+        assert (
+            ret.shape[1] == tmem
+        ), """Second dimension after distribution of the
+        rows between PEs is not as expected (tmem)"""
+        assert (
+            ret.shape[2] == n_thres_steps
+        ), """Third dimension after distribution of the
+        rows between PEs is not as expected (n_thres_steps)"""
+        return ret.reshape(1, pe, tmem, n_thres_steps)
+
+    def generate_params(self, model, path):
+        code_gen_dir = path
+        # save thresholds in thresh.h
+        thresholds = model.get_initializer(self.onnx_node.input[1])
+
+        threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
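+        # thresholds are always exported as 32-bit signed integers here,
+        # independent of the actual threshold value range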
+        tdt = DataType.INT32
+        thresholds_hls_code = numpy_to_hls_code(
+            threshold_tensor, tdt, "thresholds", False, True
+        )
+        # write thresholds into thresh.h
+        f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
+        tdt_hls = tdt.get_hls_datatype_str()
+        # use binary to export bipolar activations
+        export_odt = self.get_output_datatype()
+        if self.get_output_datatype() == DataType.BIPOLAR:
+            export_odt = DataType.BINARY
+        odt_hls = export_odt.get_hls_datatype_str()
+        f_thresh.write(
+            "static ThresholdsActivation<{},{},{},{},{},{},{}> threshs \
+            = ".format(
+                self.calc_tmem(),
+                self.get_nodeattr("PE"),
+                threshold_tensor.shape[-1],
+                tdt_hls,
+                odt_hls,
+                export_odt.min(),
+                "std::less_equal<%s>" % tdt_hls,
+            )
+        )
+        f_thresh.write(thresholds_hls_code)
+        f_thresh.close()
+
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+
+        # TODO ensure codegen dir exists
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        # create a npy file for each input of the node (in_ind is input index)
+        in_ind = 0
+        for inputs in node.input:
+            # it is assumed that the first input of the node is the data input
+            # the second input is the thresholds
+            if in_ind == 0:
+                assert (
+                    str(context[inputs].dtype) == "float32"
+                ), """Input datatype is
+                not float32 as expected."""
+                expected_inp_shape = self.get_folded_input_shape()
+                reshaped_input = context[inputs].reshape(expected_inp_shape)
+                if self.get_input_datatype() == DataType.BIPOLAR:
+                    # store bipolar activations as binary
+                    reshaped_input = (reshaped_input + 1) / 2
+                    export_idt = DataType.BINARY
+                else:
+                    export_idt = self.get_input_datatype()
+                # make copy before saving the array
+                reshaped_input = reshaped_input.copy()
+                np.save(
+                    os.path.join(code_gen_dir, "input_{}.npy".format(in_ind)),
+                    reshaped_input,
+                )
+            elif in_ind > 1:
+                raise Exception("Unexpected input found for Thresholding_Batch")
+            in_ind += 1
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_output(context)
+            # reinterpret binary output as bipolar where needed
+            if self.get_output_datatype() == DataType.BIPOLAR:
+                out = context[node.output[0]]
+                out = 2 * out - 1
+                context[node.output[0]] = out
+            assert (
+                context[node.output[0]].shape == self.get_folded_output_shape()
+            ), """Output shape is not as expected"""
+            # reshape output to have expected shape
+            oshape = self.get_normal_output_shape()
+            context[node.output[0]] = context[node.output[0]].reshape(*oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            inp = npy_to_rtlsim_input(
+                "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+            )
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
+            output = self.rtlsim(sim, inp)
+            odt = self.get_output_datatype()
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_npy_path = "{}/output.npy".format(code_gen_dir)
+            out_shape = self.get_folded_output_shape()
+            rtlsim_output_to_npy(
+                output, out_npy_path, odt, out_shape, packed_bits, target_bits
+            )
+
+            # load and reshape output
+            output = np.load(out_npy_path)
+            oshape = self.get_normal_output_shape()
+            output = np.asarray([output], dtype=np.float32).reshape(*oshape)
+            context[node.output[0]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "activations.hpp"']
+        self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
+
+    # TODO check and add whatever is missing
+    def defines(self, var):
+        numInputVectors = list(self.get_nodeattr("numInputVectors"))
+        numReps = numInputVectors[0]
+        self.code_gen_dict["$DEFINES$"] = [
+            """#define NumChannels1 {}\n #define PE1 {}\n #define numReps {}""".format(
+                self.get_nodeattr("NumChannels"), self.get_nodeattr("PE"), numReps,
+            )
+        ]
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        # note: the innermost dim is not reversed for the input (reverse_inner=false)
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0, false);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        tmpl_args = self.get_template_param_values()
+        # TODO: why put some template parameters into defines and not others?
+        # should ImgDim be defined or just filled in here like we do now?
+        node = self.onnx_node
+        ishape = self.get_folded_input_shape()
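+        # a 3-dim folded input (batch, fold, PE) is FC-style, so ImgDim = 1;
+        # a 5-dim folded input (batch, H, W, fold, PE) is conv-style and
+        # ImgDim is the spatial dimension (assumes square images, H == W)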
+        if len(ishape) == 3:
+            imgdim = 1
+        elif len(ishape) == 5:
+            imgdim = ishape[1]
+        else:
+            raise Exception("""Unexpeted input shape""")
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            """{}<{}, NumChannels1, PE1, {}, {}>
+            (in0, out, threshs, numReps);""".format(
+                node.op_type, imgdim, tmpl_args["TSrcI"], tmpl_args["TDstI"],
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        if dtype == DataType.BIPOLAR:
+            # use binary for bipolar storage
+            dtype = DataType.BINARY
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_out = "%s/output.npy" % code_gen_dir
+        shape = self.get_folded_output_shape()
+        shape_cpp_str = str(shape).replace("(", "{").replace(")", "}")
+
+        # note: the innermost dim is not reversed for the output
+        self.code_gen_dict["$DATAOUTSTREAM$"] = [
+            'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s", false);'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                shape_cpp_str,
+                npy_out,
+            )
+        ]
+
+    def save_as_npy(self):
+        self.code_gen_dict["$SAVEASCNPY$"] = []
+
+    def blackboxfunction(self):
+        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+            """void {}(hls::stream<ap_uint<{}>> &in0,
+                hls::stream<ap_uint<{}>> &out
+                )""".format(
+                self.onnx_node.name,
+                self.get_instream_width(),
+                self.get_outstream_width(),
+            )
+        ]
+
+    def pragmas(self):
+        self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out")
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS INTERFACE ap_ctrl_none port=return"
+        )
+
+        # the threshold tensor is acc_type [PE][TMEM][N_THRES]
+        # partition for parallel access along PE and N_THRES
+        # dimensions (dims 1 and 3)
+        self.code_gen_dict["$PRAGMAS$"].append(
+            (
+                "#pragma HLS ARRAY_PARTITION variable=threshs.m_thresholds "
+                "complete dim=1"
+            )
+        )
+        self.code_gen_dict["$PRAGMAS$"].append(
+            (
+                "#pragma HLS ARRAY_PARTITION variable=threshs.m_thresholds "
+                "complete dim=3"
+            )
+        )
+        # set resource type
+        ram_style = self.get_nodeattr("ram_style")
+        pe = self.get_nodeattr("PE")
+        ich = self.get_nodeattr("NumChannels")
+        # if PE is less than NumChannels, assign cores according to ram_style;
+        # otherwise if PE == NumChannels, Vivado HLS will unroll to FFs
+        if pe < ich:
+            if ram_style == "distributed":
+                self.code_gen_dict["$PRAGMAS$"].append(
+                    (
+                        "#pragma HLS RESOURCE variable=threshs.m_thresholds "
+                        "core=ROM_2P_LUTRAM"
+                    )
+                )
+            elif ram_style == "block":
+                self.code_gen_dict["$PRAGMAS$"].append(
+                    (
+                        "#pragma HLS RESOURCE variable=threshs.m_thresholds "
+                        "core=ROM_2P_BRAM"
+                    )
+                )
+            else:
+                raise Exception(
+                    """Invalid value for attribute ram_style! Is currently set to: {}
+                has to be set to one of ("block", "distributed")""".format(
+                        ram_style
+                    )
+                )
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index ac1f49ee26d1c23bd1b0e67ae4ba7e0c2b55b435..0d62862c222b44d2e507a90a80bfcd4fa405d3fe 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -44,8 +44,10 @@ from finn.custom_op.fpgadataflow.streamingdatawidthconverter_batch import (
     StreamingDataWidthConverter_Batch,
 )
 from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
+from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
+from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
 
 # create a mapping of all known CustomOp names and classes
 custom_op = {}
@@ -62,8 +64,10 @@ custom_op["MaxPoolNHWC"] = MaxPoolNHWC
 custom_op["StreamingDataWidthConverter_Batch"] = StreamingDataWidthConverter_Batch
 custom_op["StreamingFIFO"] = StreamingFIFO
 custom_op["GlobalAccPool_Batch"] = GlobalAccPool_Batch
+custom_op["Thresholding_Batch"] = Thresholding_Batch
 custom_op["AddStreams_Batch"] = AddStreams_Batch
 custom_op["LabelSelect_Batch"] = LabelSelect_Batch
+custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
 
 
 def getCustomOp(node):
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index dbd98623c4cdf5baca9fa9c137debf8be0f70981..3ff86cab48d365c10e69bc2c764e8083c6a36880 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -33,6 +33,7 @@ from finn.transformation import Transformation
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
+import finn.core.data_layout as DataLayout
 
 
 class InferConvInpGen(Transformation):
@@ -398,3 +399,59 @@ class InferQuantizedStreamingFCLayer(Transformation):
             model = model.transform(InferShapes())
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class InferThresholdingLayer(Transformation):
+    """Convert any MultiThreshold into a standalone thresholding HLS layer."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for node in graph.node:
+            node_ind += 1
+            if node.op_type == "MultiThreshold":
+                thl_input = node.input[0]
+                thl_threshold = node.input[1]
+                thl_output = node.output[0]
+                thl_in_shape = model.get_tensor_shape(thl_input)
+                idt = model.get_tensor_datatype(thl_input)
+
+                # skip conversion for layers with float input
+                if not idt.is_integer():
+                    continue
+
+                # skip conversion if input is not NHWC or NC
+                thl_in_layout = model.get_tensor_layout(thl_input)
+                if thl_in_layout != DataLayout.NHWC and thl_in_layout != DataLayout.NC:
+                    continue
+
+                # now safe to assume number of channels is in last dimension
+                ifc = int(thl_in_shape[-1])
+                # create node with no parallelization first
+                pe = 1
+                assert ifc % pe == 0, "Requirement IFC divisible by PE is violated."
+
+                odt = model.get_tensor_datatype(thl_output)
+                # create and insert new Thresholding_Batch node
+                new_node = helper.make_node(
+                    "Thresholding_Batch",
+                    [thl_input, thl_threshold],
+                    [thl_output],
+                    domain="finn",
+                    backend="fpgadataflow",
+                    NumChannels=ifc,
+                    PE=pe,
+                    inputDataType=idt.name,
+                    outputDataType=odt.name,
+                    numInputVectors=list(thl_in_shape[:-1]),
+                )
+                graph.node.insert(node_ind, new_node)
+                # remove old node
+                graph.node.remove(node)
+                graph_modified = True
+
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
diff --git a/src/finn/util/fpgadataflow.py b/src/finn/util/fpgadataflow.py
index 7b66d092107c27decca68926a0667333bebedbe0..d1669444e55cb0fddb2690e51849c4603d47d32c 100644
--- a/src/finn/util/fpgadataflow.py
+++ b/src/finn/util/fpgadataflow.py
@@ -127,3 +127,91 @@ def is_fpgadataflow_node(node):
                     is_node = True
 
     return is_node
+
+
+def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file=""):
+    """Runs the pyverilator simulation by passing the input values to the simulation,
+    toggle the clock and observing the execution time. Function contains also an
+    observation loop that can abort the simulation if no output value is produced
+    after a set number of cycles. Can handle multiple i/o streams. See function
+    implementation for details on how the top-level signals should be named.
+
+    sim: the PyVerilator object for simulation
+    io_dict: a dict of dicts in the following format:
+            {"inputs" : {"in0" : <input_data>, "in1" : <input_data>},
+             "outputs" : {"out0" : [], "out1" : []} }
+            <input_data> is a list of Python arbitrary-precision ints indicating
+            what data to push into the simulation, and the output lists are
+            similarly filled when the simulation is complete
+    num_out_values: total number of output values to be read from the
+                    simulation before it finishes and returns.
+
+    returns: number of clock cycles elapsed for completion
+
+    """
+
+    if trace_file != "":
+        sim.start_vcd_trace(trace_file)
+
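+    # Vivado HLS AXI-Stream naming convention is assumed for all i/o streams:
+    # <name>_V_V_TDATA, <name>_V_V_TVALID, <name>_V_V_TREADY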
+    for outp in io_dict["outputs"]:
+        sim.io[outp + "_V_V_TREADY"] = 1
+
+    # observe if output is completely calculated
+    # total_cycle_count will contain the number of cycles the calculation ran
+    output_done = False
+    total_cycle_count = 0
+    output_count = 0
+    old_output_count = 0
+
+    # avoid infinite looping of simulation by aborting when there is no change
+    # in output values after a set number of cycles (the liveness threshold)
+    no_change_count = 0
+    liveness_threshold = pyverilate_get_liveness_threshold_cycles()
+
+    while not (output_done):
+        for inp in io_dict["inputs"]:
+            inputs = io_dict["inputs"][inp]
+            sim.io[inp + "_V_V_TVALID"] = 1 if len(inputs) > 0 else 0
+            sim.io[inp + "_V_V_TDATA"] = inputs[0] if len(inputs) > 0 else 0
+            if sim.io[inp + "_V_V_TREADY"] == 1 and sim.io[inp + "_V_V_TVALID"] == 1:
+                inputs = inputs[1:]
+            io_dict["inputs"][inp] = inputs
+
+        for outp in io_dict["outputs"]:
+            outputs = io_dict["outputs"][outp]
+            if sim.io[outp + "_V_V_TVALID"] == 1 and sim.io[outp + "_V_V_TREADY"] == 1:
+                outputs = outputs + [sim.io[outp + "_V_V_TDATA"]]
+                output_count += 1
+            io_dict["outputs"][outp] = outputs
+
+        sim.io.ap_clk = 1
+        sim.io.ap_clk = 0
+
+        total_cycle_count = total_cycle_count + 1
+
+        if output_count == old_output_count:
+            no_change_count = no_change_count + 1
+        else:
+            no_change_count = 0
+            old_output_count = output_count
+
+        # check if all expected output words received
+        if output_count == num_out_values:
+            output_done = True
+
+        # end sim on timeout
+        if no_change_count == liveness_threshold:
+            if trace_file != "":
+                sim.flush_vcd_trace()
+                sim.stop_vcd_trace()
+            raise Exception(
+                "Error in simulation! Takes too long to produce output. "
+                "Consider setting the LIVENESS_THRESHOLD env.var. to a "
+                "larger value."
+            )
+
+    if trace_file != "":
+        sim.flush_vcd_trace()
+        sim.stop_vcd_trace()
+
+    return total_cycle_count
diff --git a/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py b/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py
index e03090f0581eebf68cac7baffb6888a6992df68d..48803c9614f53a3a149c6eaac4289d10086513a5 100644
--- a/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py
+++ b/tests/fpgadataflow/test_convert_to_hls_layers_cnv.py
@@ -39,6 +39,7 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_data_layouts import InferDataLayouts
 from finn.transformation.streamline import Streamline
 from finn.util.test import get_test_model_trained
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
@@ -54,7 +55,9 @@ export_onnx_path_cnv = "test_output_cnv.onnx"
 
 
 @pytest.mark.vivado
-def test_convert_to_hls_layers_cnv_w1a1():
+# Standalone or fused thresholding-based activation
+@pytest.mark.parametrize("fused_activation", [True, False])
+def test_convert_to_hls_layers_cnv_w1a1(fused_activation):
     cnv = get_test_model_trained("CNV", 1, 1)
     bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
     model = ModelWrapper(export_onnx_path_cnv)
@@ -69,6 +72,7 @@ def test_convert_to_hls_layers_cnv_w1a1():
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(ConvertBipolarMatMulToXnorPopcount())
     model = model.transform(Streamline())
+    model = model.transform(InferDataLayouts())
     # model.save("golden.onnx")
     # load one of the test vectors
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
@@ -80,6 +84,10 @@ def test_convert_to_hls_layers_cnv_w1a1():
     expected_ctx = oxe.execute_onnx(model, input_dict, True)
     expected = expected_ctx[model.graph.output[0].name]
 
+    # if we infer thresholding first, all MultiThresholds get converted to HLS
+    # subsequently, the FC inference will generate passthrough MVAUs
+    if not fused_activation:
+        model = model.transform(to_hls.InferThresholdingLayer())
     model = model.transform(to_hls.InferBinaryStreamingFCLayer())
     model = model.transform(to_hls.InferQuantizedStreamingFCLayer())
     for node in model.graph.node:
@@ -102,7 +110,12 @@ def test_convert_to_hls_layers_cnv_w1a1():
     model = model.transform(to_hls.InferStreamingMaxPool())
     # check topology status
     finn_nodes = model.get_finn_nodes()
-    assert len(finn_nodes) == 18
+    if fused_activation:
+        assert len(finn_nodes) == 18
+    else:
+        assert len(finn_nodes) == 26
+        thr_nodes = model.get_nodes_by_op_type("Thresholding_Batch")
+        assert len(thr_nodes) == 8
     non_finn_nodes = model.get_non_finn_nodes()
     assert len(non_finn_nodes) == 4
     exp_non_finn_nodes = ["Transpose", "Reshape", "Mul", "Add"]
diff --git a/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb84be59333ef0e696204c9064fcf77e35b5d9b
--- /dev/null
+++ b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
+    ReplaceVerilogRelPaths,
+)
+
+
+def make_dupstreams_modelwrapper(ch, pe, idim, idt):
+    shape = [1, idim, idim, ch]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, shape)
+    outp0 = helper.make_tensor_value_info("outp0", TensorProto.FLOAT, shape)
+    outp1 = helper.make_tensor_value_info("outp1", TensorProto.FLOAT, shape)
+
+    dupstrm_node = helper.make_node(
+        "DuplicateStreams_Batch",
+        ["inp"],
+        ["outp0", "outp1"],
+        domain="finn",
+        backend="fpgadataflow",
+        NumChannels=ch,
+        PE=pe,
+        inputDataType=idt.name,
+        numInputVectors=[1, idim, idim],
+    )
+    graph = helper.make_graph(
+        nodes=[dupstrm_node], name="graph", inputs=[inp], outputs=[outp0, outp1]
+    )
+
+    model = helper.make_model(graph, producer_name="addstreams-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+
+    return model
+
+
+def prepare_inputs(input_tensor, idt):
+    return {"inp": input_tensor}
+
+
+# data type
+@pytest.mark.parametrize("idt", [DataType.INT4, DataType.UINT16])
+# channels
+@pytest.mark.parametrize("ch", [64])
+# folding factor, -1 is maximum possible
+@pytest.mark.parametrize("fold", [-1, 2, 1])
+# image dimension
+@pytest.mark.parametrize("imdim", [7])
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+@pytest.mark.vivado
+def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, exec_mode):
+    if fold == -1:
+        pe = 1
+    else:
+        pe = ch // fold
+    assert ch % pe == 0
+
+    # generate input data
+    x = gen_finn_dt_tensor(idt, (1, imdim, imdim, ch))
+
+    model = make_dupstreams_modelwrapper(ch, pe, imdim, idt)
+
+    if exec_mode == "cppsim":
+        model = model.transform(PrepareCppSim())
+        model = model.transform(CompileCppSim())
+        model = model.transform(SetExecMode("cppsim"))
+    elif exec_mode == "rtlsim":
+        model = model.transform(SetExecMode("rtlsim"))
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(PrepareIP("xc7z020clg400-1", 5))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(ReplaceVerilogRelPaths())
+        model = model.transform(PrepareRTLSim())
+    else:
+        raise Exception("Unknown exec_mode")
+
+    # prepare input data and execute
+    input_dict = prepare_inputs(x, idt)
+    output_dict = oxe.execute_onnx(model, input_dict)
+    y0 = output_dict["outp0"]
+    y1 = output_dict["outp1"]
+    expected_y = x
+
+    assert (y0 == expected_y).all(), exec_mode + " failed"
+    assert (y1 == expected_y).all(), exec_mode + " failed"
diff --git a/tests/fpgadataflow/test_fpgadataflow_thresholding.py b/tests/fpgadataflow/test_fpgadataflow_thresholding.py
new file mode 100644
index 0000000000000000000000000000000000000000..50b990f13494f22e985406791445b406e9946147
--- /dev/null
+++ b/tests/fpgadataflow/test_fpgadataflow_thresholding.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.analysis.fpgadataflow.hls_synth_res_estimation import hls_synth_res_estimation
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.custom_op.multithreshold import multithreshold
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
+    ReplaceVerilogRelPaths,
+)
+
+
+def make_single_thresholding_modelwrapper(T, pe, idt, odt):
+    NumChannels = T.shape[0]
+
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, NumChannels])
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, NumChannels])
+
+    node_inp_list = ["inp", "thresh"]
+
+    Thresholding_node = helper.make_node(
+        "Thresholding_Batch",
+        node_inp_list,
+        ["outp"],
+        domain="finn",
+        backend="fpgadataflow",
+        NumChannels=NumChannels,
+        PE=pe,
+        inputDataType=idt.name,
+        outputDataType=odt.name,
+    )
+    graph = helper.make_graph(
+        nodes=[Thresholding_node],
+        name="thresholding_graph",
+        inputs=[inp],
+        outputs=[outp],
+    )
+
+    model = helper.make_model(graph, producer_name="thresholding-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+
+    model.set_tensor_datatype("thresh", idt)
+    model.set_initializer("thresh", T)
+    return model
+
+
+# activation datatype (determines output datatype)
+@pytest.mark.parametrize("act", [DataType.INT4, DataType.BIPOLAR])
+# input datatype
+@pytest.mark.parametrize("idt", [DataType.INT16, DataType.UINT16])
+# folding, -1 is maximum possible
+@pytest.mark.parametrize("nf", [-1, 2, 1])
+# number of input features
+@pytest.mark.parametrize("ich", [16])
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+@pytest.mark.vivado
+@pytest.mark.slow
+def test_fpgadataflow_thresholding(idt, act, nf, ich, exec_mode):
+    if nf == -1:
+        nf = ich
+    pe = ich // nf
+    assert ich % pe == 0
+
+    # generate input data
+    x = gen_finn_dt_tensor(idt, (1, ich))
+
+    odt = act
+    n_steps = act.get_num_possible_values() - 1
+    T = np.random.randint(idt.min(), idt.max() + 1, (ich, n_steps)).astype(np.float32)
+    # provide non-decreasing thresholds
+    T = np.sort(T, axis=1)
+
+    model = make_single_thresholding_modelwrapper(T, pe, idt, odt)
+
+    if exec_mode == "cppsim":
+        model = model.transform(PrepareCppSim())
+        model = model.transform(CompileCppSim())
+        model = model.transform(SetExecMode("cppsim"))
+    elif exec_mode == "rtlsim":
+        model = model.transform(SetExecMode("rtlsim"))
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(PrepareIP("xc7z020clg400-1", 5))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(ReplaceVerilogRelPaths())
+        model = model.transform(PrepareRTLSim())
+    else:
+        raise Exception("Unknown exec_mode")
+
+    # package input data as dictionary
+    input_dict = {"inp": x}
+
+    y = multithreshold(x, T)
+    if act == DataType.BIPOLAR:
+        # binary to bipolar
+        y = 2 * y - 1
+    else:
+        # signed offset
+        y += act.min()
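+    # multithreshold returns the number of thresholds crossed per element
+    # (0..n_steps), so the signed case is shifted by act.min() into the
+    # expected output range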
+
+    oshape = model.get_tensor_shape("outp")
+    y_expected = y.reshape(oshape)
+    # execute model
+    y_produced = oxe.execute_onnx(model, input_dict)["outp"]
+
+    y_produced = y_produced.reshape(y_expected.shape)
+
+    assert (y_produced == y_expected).all(), "cppsim failed"
+
+    if exec_mode == "rtlsim":
+        hls_synt_res_est = model.analysis(hls_synth_res_estimation)
+        assert "Thresholding_Batch_0" in hls_synt_res_est