diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 80e5261c9b6568e0f340ae0add86d269c036ff16..b312737c317517ca0ab19c74cf22284b5977b661 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -13,9 +13,9 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=026a509186b7e7b0b65d46a2f905043d41069306
+BREVITAS_COMMIT=f9a27226d4acf1661dd38bc449f71f89e0983cce
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
-HLSLIB_COMMIT=8aed899c278c36c977a249558d71795086cf852c
+HLSLIB_COMMIT=8f9f2018762f654f196b666838aeaf6fc730ad9a
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
 PYNQSHELL_COMMIT=0c82a61b0ec1a07fa275a14146233824ded7a13d
 OMX_COMMIT=1bae737669901e762f581af73348332b5c4b2ada
diff --git a/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py b/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
index 027524dfdc3fdd45a37892bd1b0a510b5b3866a7..ad68a4bde29123b2498ac7789048bcd2e13bf3bc 100644
--- a/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
+++ b/src/finn/custom_op/fpgadataflow/channelwise_op_batch.py
@@ -41,18 +41,18 @@ from finn.util.data_packing import (
 )
 from . import templates
 
-# ONNX i/o tensor shape assumptions for Thresholding:
+# ONNX i/o tensor shape assumptions for channelwise ops:
 # input 0 is the input tensor, shape (..., NumChannels)
-# input 1 is the threshold tensor, shape (NumChannels, n_thres)
+# input 1 is the channelwise parameter tensor, shape (NumChannels, params_per_channel)
 # output 0 is the output tensor, shape (..., NumChannels) - same as input
 # the ... here can be any shape (representing groups of vectors)
 
-# by setting Func appropriately, this function can implement
-# any channel-wise operation, including Add, Mul, Thresholding
-
 
 class ChannelwiseOp_Batch(HLSCustomOp):
-    """Class that corresponds to finn-hls Thresholding_Batch function."""
+    """Class that corresponds to finn-hls Thresholding_Batch function.
+    It can implement a variety of channel-wise parametrized operations,
+    including Add, Mul and multi-thresholding.
+    """
 
     def __init__(self, onnx_node):
         super().__init__(onnx_node)
@@ -60,13 +60,16 @@ class ChannelwiseOp_Batch(HLSCustomOp):
 
     def get_nodeattr_types(self):
         my_attrs = {
+            # channelwise "map" function to apply:
+            # one of cmp_le, cmp_ge, add, mul
             "Func": ("s", False, "cmp_le"),
             "PE": ("i", True, 0),
             "NumChannels": ("i", True, 0),
-            # string defining memory type
+            # string defining memory resource type for parameters
             "ram_style": ("s", False, "distributed"),
             # FINN DataTypes for inputs, weights, outputs
             "inputDataType": ("s", True, ""),
+            "paramDataType": ("s", True, ""),
             "outputDataType": ("s", True, ""),
             # input and output FIFO depths
             "inFIFODepth": ("i", False, 0),
@@ -81,7 +84,8 @@ class ChannelwiseOp_Batch(HLSCustomOp):
         return my_attrs
 
     def calc_tmem(self):
-        """Calculates and returns TMEM."""
+        """Calculates and returns TMEM, the depth of the memory used
+        to store the channelwise op parameters."""
         chn = self.get_nodeattr("NumChannels")
         pe = self.get_nodeattr("PE")
         return chn // pe
@@ -107,7 +111,8 @@ class ChannelwiseOp_Batch(HLSCustomOp):
         # check input datatype against property
         idt_name = self.get_input_datatype().name
         exp_idt_name = self.get_nodeattr("inputDataType")
-        assert exp_idt_name == idt_name, "Bad input DataType for Thresholding layer"
+        assert exp_idt_name == idt_name, "Bad input DataType for ChannelwiseOp layer"
+        # TODO: dynamically infer/update odt based on idt as done in ConvertToHLSLayers?
         # set output datatype from property
         odt = self.get_output_datatype()
         model.set_tensor_datatype(node.output[0], odt)
@@ -136,6 +141,7 @@ class ChannelwiseOp_Batch(HLSCustomOp):
             self.get_nodeattr("NumChannels")
             self.get_nodeattr("PE")
             self.get_nodeattr("inputDataType")
+            self.get_nodeattr("paramDataType")
             self.get_nodeattr("outputDataType")
             info_messages.append("All necessary attributes exist")
         except Exception:
@@ -278,18 +284,8 @@ class ChannelwiseOp_Batch(HLSCustomOp):
         code_gen_dir = path
         # save thresholds in params.h
         parameters = model.get_initializer(self.onnx_node.input[1])
-
         parameter_tensor = self.get_hls_compatible_parameter_tensor(parameters)
-
-        # determine parameters data type from range of threshold and input tensors
-        p_min = parameters.min()
-        p_max = parameters.max()
-        p_absmax = max(abs(p_min), abs(p_max))
-        if p_min < 0:
-            p_min = min(p_min, -p_absmax - 1)
-            pdt = DataType.get_smallest_possible(p_min)
-        else:
-            pdt = DataType.get_smallest_possible(p_max)
+        pdt = DataType[self.get_nodeattr("paramDataType")]
 
         parameters_hls_code = numpy_to_hls_code(
             parameter_tensor, pdt, "parameters", False, True
@@ -534,8 +530,8 @@ class ChannelwiseOp_Batch(HLSCustomOp):
             "#pragma HLS INTERFACE ap_ctrl_none port=return"
         )
 
-        # the threshold tensor is acc_type [PE][TMEM][N_THRES]
-        # partition for parallel access along PE and N_THRES
+        # the channelwise parameter tensor is acc_type [PE][TMEM][N_PARAMS_PER_CHANNEL]
+        # partition for parallel access along PE and N_PARAMS_PER_CHANNEL
         # dimensions (dims 1 and 3)
         self.code_gen_dict["$PRAGMAS$"].append(
             (
diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7edc24d0e24eef1154293caca2519ab3aa68358
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -0,0 +1,395 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import numpy as np
+
+from finn.custom_op.fpgadataflow import HLSCustomOp
+from finn.core.datatype import DataType
+from onnx import TensorProto, helper
+from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
+
+
+class Pool_Batch(HLSCustomOp):
+    """Class that corresponds to finn-hlslib Pool_batch function.
+    Requires ConvolutionInputGenerator(depthwise == 1) to format its input
+
+    TODO: explain input shape (to reuse im2col code)
+    Input shape (BatchSize,OutImgDim,OutImgDim,KernelSize^2*Channels)
+    Output shape (BatchSize,OutImgDim,OutImgDim,Channels)
+
+    # note: the actual data layout produced by the hlslib kernels is different
+    # for depthwise ops.
+    # * depthwise SWG: (1, OFMDim, OFMDim, IFMChannels/PE, K, K, PE)
+
+    Channels can be folded using PE (SIMD from the input perspective)
+    TODO: doc
+    """
+
+    def get_nodeattr_types(self):
+        my_attrs = {
+            "Channels": ("i", True, 0),
+            "PE": ("i", True, 1),
+            "KernelSize": ("i", True, 0),
+            # Function:
+            #  - MaxPool
+            #  - AvgPool (not yet supported, but HLSLIB does)
+            #  - AccPool (not yet supported, but HLSLIB does)
+            "Function": ("s", True, ""),
+            "OutImgDim": ("i", True, 0),
+            # FINN DataTypes for inputs/outputs
+            "dataType": ("s", True, ""),
+            "BatchSize": ("i", False, 1),
+        }
+
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        return DataType[self.get_nodeattr("dataType")]
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output."""
+        fxn = self.get_nodeattr("Function")
+        if fxn == "MaxPool":
+            # Same as input
+            return DataType[self.get_nodeattr("dataType")]
+        else:
+            raise Exception("Pool_Batch doesn't currently support " + fxn)
+
+    def get_normal_input_shape(self):
+        ifm_ch = self.get_nodeattr("Channels")
+        odim = self.get_nodeattr("OutImgDim")
+        batch_size = self.get_nodeattr("BatchSize")
+        k = self.get_nodeattr("KernelSize")
+        ishape = (batch_size, odim, odim, k * k * ifm_ch)
+        return ishape
+
+    def get_folded_input_shape(self):
+        normal_ishape = list(self.get_normal_input_shape())
+        ifm_ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        assert ifm_ch % pe == 0, "PE must divide input channels"
+        fold = int(normal_ishape[-1] / pe)
+        folded_ishape = normal_ishape[:-1] + [fold, pe]
+        return tuple(folded_ishape)
+
+    def get_normal_output_shape(self):
+        ofm_ch = self.get_nodeattr("Channels")
+        odim = self.get_nodeattr("OutImgDim")
+        batch_size = self.get_nodeattr("BatchSize")
+        oshape = (batch_size, odim, odim, ofm_ch)
+        return oshape
+
+    def get_folded_output_shape(self):
+        normal_oshape = list(self.get_normal_output_shape())
+        ifm_ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        assert ifm_ch % pe == 0, "PE must divide input channels"
+        fold = int(ifm_ch / pe)
+        folded_oshape = normal_oshape[:-1] + [fold, pe]
+        return tuple(folded_oshape)
+
+    def get_number_output_values(self):
+        folded_oshape = self.get_folded_output_shape()
+        return np.prod(folded_oshape[1:-1])
+
+    def get_instream_width(self):
+        dt_bits = self.get_input_datatype().bitwidth()
+        pe = self.get_nodeattr("PE")
+        # ofm_ch = self.get_nodeattr("Channels")
+        # k = self.get_nodeattr("KernelSize")
+        # assert ifm_ch % pe == 0, "PE must divide input channels"
+        # simd = int(ifm_ch/pe)
+        in_width = int(dt_bits * pe)
+        return in_width
+
+    def get_outstream_width(self):
+        fxn = self.get_nodeattr("Function")
+        if fxn == "MaxPool":
+            return self.get_instream_width()
+        else:
+            raise Exception("Pool_Batch doesn't currently support " + fxn)
+
+    def make_shape_compatible_op(self, model):
+        exp_ishape = self.get_normal_input_shape()
+        oshape = self.get_normal_output_shape()
+        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
+        assert ishape == exp_ishape, "Unexpected input shape for Pool_Batch."
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # data type stays the same
+        dtype = self.get_output_datatype()
+        model.set_tensor_datatype(node.output[0], dtype)
+
+    def verify_node(self):
+        info_messages = []
+
+        # verify that "domain" is set to "finn"
+        domain_value = self.onnx_node.domain
+        if domain_value == "finn":
+            info_messages.append("Attribute domain is set correctly")
+        else:
+            info_messages.append('Attribute domain should be set to "finn"')
+
+        # verify that "backend" is set to "fpgadataflow"
+        backend_value = self.get_nodeattr("backend")
+        if backend_value == "fpgadataflow":
+            info_messages.append("Attribute backend is set correctly")
+        else:
+            info_messages.append('Attribute backend should be set to "fpgadataflow"')
+
+        # verify the number of inputs
+        if len(self.onnx_node.input) == 1:
+            info_messages.append("The number of inputs is correct")
+        else:
+            info_messages.append("""Pool_Batch needs 1 data input""")
+
+        # check supported function
+        fnx = self.get_nodeattr("Function")
+        if fnx == "MaxPool":
+            info_messages.append(
+                "Attribute Function contains a supported pool function"
+            )
+        else:
+            info_messages.append(
+                "Attribute Function contains an unsupported pool function"
+            )
+        return info_messages
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "maxpool.h"']
+        self.code_gen_dict["$GLOBALS$"] += ['#include "pool.hpp"']
+
+    def defines(self, var):
+        self.code_gen_dict["$DEFINES$"] = []
+
+        ifm_ch = self.get_nodeattr("Channels")
+        self.code_gen_dict["$DEFINES$"] += ["#define Channels {}".format(ifm_ch)]
+
+        pe = self.get_nodeattr("PE")
+        self.code_gen_dict["$DEFINES$"] += ["#define PE {}".format(pe)]
+
+        k = self.get_nodeattr("KernelSize")
+        self.code_gen_dict["$DEFINES$"] += ["#define KernelSize {}".format(k)]
+
+        odim = self.get_nodeattr("OutImgDim")
+        self.code_gen_dict["$DEFINES$"] += ["#define OFMDim {}".format(odim)]
+
+        numReps = self.get_nodeattr("BatchSize")
+        self.code_gen_dict["$DEFINES$"] += ["#define numReps {}".format(numReps)]
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        if dtype == DataType.BIPOLAR:
+            # use binary for bipolar storage
+            dtype = DataType.BINARY
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0,false);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        idt = self.get_input_datatype()
+        i_hls_dt = idt.get_hls_datatype_str()
+        odt = self.get_output_datatype()
+        o_hls_dt = odt.get_hls_datatype_str()
+
+        self.code_gen_dict["$DOCOMPUTE$"] = []
+
+        fxn = self.get_nodeattr("Function")
+        if fxn == "MaxPool":
+            self.code_gen_dict["$DOCOMPUTE$"] += [
+                "MaxPoolFunction<{},KernelSize> pool_fxn;".format(i_hls_dt)
+            ]
+        else:
+            raise Exception("Pool_Batch doesn't currently support " + fxn)
+
+        self.code_gen_dict["$DOCOMPUTE$"] += [
+            """Pool_batch<Channels, PE, KernelSize,Slice<{} >, Slice< {} > >
+        (in0,out, pool_fxn, OFMDim*OFMDim*numReps);""".format(
+                i_hls_dt, o_hls_dt
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        if dtype == DataType.BIPOLAR:
+            # use binary for bipolar storage
+            dtype = DataType.BINARY
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_out = "%s/output.npy" % code_gen_dir
+        oshape = self.get_folded_output_shape()
+        oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
+
+        self.code_gen_dict["$DATAOUTSTREAM$"] = [
+            'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s",false);'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                oshape_cpp_str,
+                npy_out,
+            )
+        ]
+
+    def save_as_npy(self):
+        self.code_gen_dict["$SAVEASCNPY$"] = []
+
+    def blackboxfunction(self):
+        packed_ibits = self.get_instream_width()
+        packed_in_hls_type = "ap_uint<%d>" % packed_ibits
+
+        packed_obits = self.get_outstream_width()
+        packed_out_hls_type = "ap_uint<%d>" % packed_obits
+        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+            "void %s(hls::stream<%s > &in0, hls::stream<%s > &out)"
+            % (self.onnx_node.name, packed_in_hls_type, packed_out_hls_type)
+        ]
+
+    def pragmas(self):
+        self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out")
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS INTERFACE ap_ctrl_none port=return"
+        )
+
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+        exp_ishape = self.get_normal_input_shape()
+        folded_ishape = self.get_folded_input_shape()
+        exp_oshape = self.get_normal_output_shape()
+        folded_oshape = self.get_folded_output_shape()
+
+        # TODO ensure codegen dir exists
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        inp = context[node.input[0]]
+
+        assert str(inp.dtype) == "float32", "Input datatype is not float32"
+        assert (
+            inp.shape == exp_ishape
+        ), """Input shape doesn't
+        match expected shape (batch_size,odim,odim,k*k*ifm_ch)."""
+
+        export_idt = self.get_input_datatype()
+        reshaped_input = inp.reshape(folded_ishape)
+
+        np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_output(context)
+            assert (
+                context[node.output[0]].shape == folded_oshape
+            ), "cppsim did not produce expected folded output shape"
+            context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            rtlsim_inp = npy_to_rtlsim_input(
+                "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+            )
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
+            rtlsim_output = self.rtlsim(sim, rtlsim_inp)
+            odt = export_idt
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_npy_path = "{}/output.npy".format(code_gen_dir)
+            out_shape = self.get_folded_output_shape()
+            rtlsim_output_to_npy(
+                rtlsim_output, out_npy_path, odt, out_shape, packed_bits, target_bits
+            )
+            # load and reshape output
+            output = np.load(out_npy_path)
+            output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
+            context[node.output[0]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        assert (
+            context[node.output[0]].shape == exp_oshape
+        ), """Output
+        shape doesn't match expected shape (1, ofm_dim, ofm_dim, k*k*ifm_ch)."""
diff --git a/src/finn/custom_op/im2col.py b/src/finn/custom_op/im2col.py
index 82a6b140f7af1be4e5c0f429d077b99c7865383e..8ed0041704d421dab587f08bcbcd9e739e8434e9 100644
--- a/src/finn/custom_op/im2col.py
+++ b/src/finn/custom_op/im2col.py
@@ -80,6 +80,8 @@ class Im2Col(CustomOp):
             "input_shape": ("s", True, ""),
             "pad_amount": ("i", False, 0),
             "pad_value": ("i", False, 0),
+            # depthwise: if != 0, infer ConvolutionInputGenerator with depthwise == 1
+            "depthwise": ("i", False, 0),
         }
 
     def make_shape_compatible_op(self, model):
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index 46d27472a9802a4c2a9004bb28c8bd09be8fbfdb..d1f8f02a00810804163918d5fd4336ab6523bde0 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -44,6 +44,7 @@ from finn.custom_op.fpgadataflow.streamingdatawidthconverter_batch import (
     StreamingDataWidthConverter_Batch,
 )
 from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
+from finn.custom_op.fpgadataflow.pool_batch import Pool_Batch
 from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch
 from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
@@ -67,6 +68,7 @@ custom_op["MaxPoolNHWC"] = MaxPoolNHWC
 custom_op["StreamingDataWidthConverter_Batch"] = StreamingDataWidthConverter_Batch
 custom_op["StreamingFIFO"] = StreamingFIFO
 custom_op["GlobalAccPool_Batch"] = GlobalAccPool_Batch
+custom_op["Pool_Batch"] = Pool_Batch
 custom_op["FMPadding_Batch"] = FMPadding_Batch
 custom_op["Thresholding_Batch"] = Thresholding_Batch
 custom_op["AddStreams_Batch"] = AddStreams_Batch
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 09ecf4f06a353cb14c6ea3ef17310092fff75bb8..afb37b1ec0e0fc19a170b19337bceafd34f11a0e 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -36,6 +36,8 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 import finn.core.data_layout as DataLayout
 from finn.util.onnx import nchw_to_nhwc
+import warnings
+from finn.util.basic import get_by_name
 
 
 class InferConvInpGen(Transformation):
@@ -58,6 +60,7 @@ class InferConvInpGen(Transformation):
                 k = i2c_inst.get_nodeattr("kernel_size")
                 pad = i2c_inst.get_nodeattr("pad_amount")
                 pad_val = i2c_inst.get_nodeattr("pad_value")
+                depthwise = i2c_inst.get_nodeattr("depthwise")
                 ifm_ch = i2c_in_shape[-1]
                 ifm_dim = i2c_in_shape[1]
                 ofm_dim = i2c_out_shape[1]
@@ -69,7 +72,11 @@ class InferConvInpGen(Transformation):
 
                 if pad > 0:
                     # if padding enabled, ensure pad_val supported by DataType
-                    assert dt.allowed(pad_val), "Im2Col DataType must support pad_val"
+                    # assert dt.allowed(pad_val),"""FMPadding_Batch DataType
+                    # must support pad_val"""
+                    assert (
+                        pad_val == 0
+                    ), "FMPadding_Batch doesn't currently support pad_val!= 0"
 
                     odim_padding = ifm_dim + 2 * pad
 
@@ -114,6 +121,7 @@ class InferConvInpGen(Transformation):
                     Stride=stride,
                     inputDataType=dt.name,
                     outputDataType=dt.name,
+                    depthwise=depthwise,
                 )
                 graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node)
                 # remove old nodes
@@ -171,6 +179,137 @@ class InferStreamingMaxPool(Transformation):
         return (model, graph_modified)
 
 
+class InferPool_Batch(Transformation):
+    """If kernel_shape > strides, replace Pool layer with  with of Im2col
+    + pool(with kernel_shape == strides), plus Transpose layers to keep the original
+    data layout."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type in ["MaxPool"]:
+                # extract pool parameters
+                k = get_by_name(n.attribute, "kernel_shape").ints[-1]
+                stride = get_by_name(n.attribute, "strides").ints[-1]
+
+                if k <= stride:
+                    continue
+
+                try:
+                    pad = get_by_name(n.attribute, "pads").ints[-1]
+                except AttributeError:
+                    pad = 0
+
+                node_input = n.input[0]
+                node_output = n.output[0]
+                idt = model.get_tensor_datatype(node_input)
+                if not idt.is_integer():
+                    continue
+
+                # odt = model.get_tensor_datatype(node_output)
+
+                ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
+                ofm_ch = ifm_ch
+                ifm_dim = model.get_tensor_shape(n.input[0])[-1]  # assume NCHW
+                ofm_dim = model.get_tensor_shape(n.output[0])[-1]  # assume NCHW
+                # create new intermediate values
+                inp_trans_out = helper.make_tensor_value_info(
+                    model.make_new_valueinfo_name(),
+                    TensorProto.FLOAT,
+                    (1, ifm_dim, ifm_dim, ifm_ch),  # NHWC
+                )
+                graph.value_info.append(inp_trans_out)
+                inp_trans_out = inp_trans_out.name
+                model.set_tensor_datatype(inp_trans_out, idt)
+
+                im2col_out = helper.make_tensor_value_info(
+                    model.make_new_valueinfo_name(),
+                    TensorProto.FLOAT,
+                    (1, ofm_dim, ofm_dim, ifm_ch * k * k),
+                )
+                graph.value_info.append(im2col_out)
+                im2col_out = im2col_out.name
+                model.set_tensor_datatype(im2col_out, idt)
+
+                pool_output = helper.make_tensor_value_info(
+                    model.make_new_valueinfo_name(),
+                    TensorProto.FLOAT,
+                    (1, ofm_dim, ofm_dim, ofm_ch),
+                )
+                graph.value_info.append(pool_output)
+                pool_output = pool_output.name
+                # model.set_tensor_datatype(pool_output, odt)
+
+                # create new nodes
+                # NCHW -> NHWC
+                inp_trans_node = helper.make_node(
+                    "Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
+                )
+
+                if n.op_type == "MaxPool":
+                    pool_fxn = "MaxPool"
+                    pad_value = idt.min()
+                else:
+                    raise Exception(
+                        "pad_value and pool_fxn not configured for {}".format(n.op_type)
+                    )
+
+                # format input tensor
+                im2col_node = helper.make_node(
+                    "Im2Col",
+                    [inp_trans_out],
+                    [im2col_out],
+                    domain="finn",
+                    stride=stride,
+                    kernel_size=k,
+                    pad_amount=pad,
+                    pad_value=pad_value,
+                    depthwise=1,
+                    input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
+                )
+
+                # Warning PE has to be equal to ifm_ch until Im2Col is replaced by
+                # ConvolutionInputGenerator with depthwise=1.
+                # For other settings the output will be incorrect due to incorrect input
+                # data layout
+                pool_node = helper.make_node(
+                    "Pool_Batch",
+                    [im2col_out],
+                    [pool_output],
+                    domain="finn",
+                    backend="fpgadataflow",
+                    dataType=idt.name,
+                    Channels=ifm_ch,
+                    PE=ifm_ch,
+                    KernelSize=k,
+                    Function=pool_fxn,
+                    OutImgDim=ofm_dim,
+                    BatchSize=1,
+                )
+
+                # NHWC -> NCHW
+                out_trans_node = helper.make_node(
+                    "Transpose", [pool_output], [node_output], perm=[0, 3, 1, 2]
+                )
+
+                # insert nodes where the conv is to preserve topological ordering
+                graph.node.insert(node_ind, inp_trans_node)
+                graph.node.insert(node_ind + 1, im2col_node)
+                graph.node.insert(node_ind + 2, pool_node)
+                graph.node.insert(node_ind + 3, out_trans_node)
+                # remove old node
+                graph.node.remove(n)
+                graph_modified = True
+
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class InferBinaryStreamingFCLayer(Transformation):
     """Convert XnorPopcountMatMul layers to
     StreamingFCLayer_Batch layers. Any immediately following MultiThreshold
@@ -496,6 +635,36 @@ class InferThresholdingLayer(Transformation):
 class InferChannelwiseLinearLayer(Transformation):
     """Convert any channel-wise Add/Mul into a HLS layer."""
 
+    def get_smallest_possible(self, vals):
+        """Returns smallest (fewest bits) possible DataType that can represent
+        value. Prefers unsigned integers where possible."""
+        vals = np.array(vals)
+        for v in vals:
+            assert int(v) == v, "Error float value"
+
+        for k in DataType.__members__:
+            dt = DataType[k]
+
+            if dt in [DataType.BIPOLAR, DataType.TERNARY, DataType.FLOAT32]:
+                # not currently supported
+                continue
+
+            if (dt.min() <= vals).all() and (vals <= dt.max()).all():
+                return dt
+
+        warnings.warn(
+            """InferChannelwiseLinearLayer: Output values may not be
+        representable with supported data types.
+        Setting maximum width data type available.
+        This will lead to errors if there are no constrains on the input
+        """
+        )
+
+        if (0 <= vals).all():
+            return DataType.UINT32
+        else:
+            return DataType.INT32
+
     def apply(self, model):
         graph = model.graph
         node_ind = 0
@@ -503,45 +672,52 @@ class InferChannelwiseLinearLayer(Transformation):
         for node in graph.node:
             node_ind += 1
             if node.op_type == "Add" or node.op_type == "Mul":
+                # assuming input[0] is dynamic
                 ll_input = node.input[0]
                 ll_output = node.output[0]
                 ll_in_shape = model.get_tensor_shape(ll_input)
-                # check that:
-                # input 1 is an initializer
-                # with shape (ch, 1)
-                # and initializer is integers
+
+                # check if input 1 has an initializer
                 ll_const = node.input[1]
                 if ll_const is not None:
                     ll_cinit = model.get_initializer(ll_const)
+                    if ll_cinit is None:
+                        # input 1 is also dynamic
+                        continue
                 else:
                     continue
 
-                ll_cinit_shape = list(ll_cinit.shape)
-                # get number of channels from input
+                # get number of channels and channel index from input
                 ll_in_layout = model.get_tensor_layout(ll_input)
                 if ll_in_layout == DataLayout.NHWC or ll_in_layout == DataLayout.NC:
+                    ch_index = -1
                     ch = ll_in_shape[-1]
                 elif ll_in_layout == DataLayout.NCHW:
+                    ch_index = 1
                     ch = ll_in_shape[1]
                 else:
                     continue
 
-                # check if the shape of initializer is compatible with
-                # number of channels, e.g. (ch,1) or (ch)
-                # TODO: verify plausible shapes
-                if np.prod(ll_cinit_shape) != ch:
+                # check if the shape of initializer is compatible
+                ll_cinit_shape = list(ll_cinit.shape)
+                if np.prod(ll_cinit_shape) == 1:
+                    warnings.warn(
+                        "Broadcasting " + str(node.op_type) + "(" + node.name + ")"
+                    )
+                    ll_cinit = np.full((ch), ll_cinit.flatten()[0])
+                elif np.prod(ll_cinit_shape) != ch or ll_cinit_shape[ch_index] != ch:
+                    # parameter shape not compatible with Channelwise_batch
                     continue
 
                 # check initializer contains integers as floats
                 if not (ll_cinit.astype(np.int32) == ll_cinit).all():
                     continue
-
                 # all initializer conditions are met
+
                 # check inputs
                 idt = model.get_tensor_datatype(ll_input)
-
-                # skip conversion for layers with float input
                 if not idt.is_integer():
+                    # skip conversion for layers with float input
                     continue
 
                 # check layout of inputs/outputs, and convert if needed
@@ -559,24 +735,32 @@ class InferChannelwiseLinearLayer(Transformation):
                     ll_output = nchw_to_nhwc(ll_output, model, node_ind, reverse=True)
                     node_ind += 1
 
-                # create node with no parallelization first
-                pe = 1
-                assert ch % pe == 0, "Requirement IFC divisable by PE is violated."
+                # get parameter data type
+                param_min = min(ll_cinit.flatten())
+                param_max = max(ll_cinit.flatten())
+                pdt = self.get_smallest_possible([param_min, param_max])
 
                 # set function and determine output data type
                 if node.op_type == "Add":
                     func = "add"
-                    if idt.signed():
-                        odt = DataType.get_smallest_possible(2 * idt.min())
-                    else:
-                        odt = DataType.get_smallest_possible(2 * idt.max())
+                    out_min = idt.min() + param_min
+                    out_max = idt.max() + param_max
+                    odt = self.get_smallest_possible([out_min, out_max])
                 elif node.op_type == "Mul":
                     func = "mul"
-                    if idt.signed():
-                        odt = DataType.get_smallest_possible(abs(idt.min()) * idt.min())
-                    else:
-                        odt = DataType.get_smallest_possible(idt.max() * idt.max())
+                    possible_limits = []
+                    possible_limits += [idt.min() * param_min]
+                    possible_limits += [idt.min() * param_max]
+                    possible_limits += [idt.max() * param_min]
+                    possible_limits += [idt.max() * param_max]
+                    odt = self.get_smallest_possible(possible_limits)
+
                 model.set_initializer(ll_const, ll_cinit.reshape(ch))
+                model.set_tensor_datatype(ll_output, odt)
+
+                # create node with no parallelization first
+                pe = 1
+                assert ch % pe == 0, "Requirement IFC divisable by PE is violated."
                 # create and insert node
                 new_node = helper.make_node(
                     "ChannelwiseOp_Batch",
@@ -588,6 +772,7 @@ class InferChannelwiseLinearLayer(Transformation):
                     NumChannels=ch,
                     PE=pe,
                     inputDataType=idt.name,
+                    paramDataType=pdt.name,
                     outputDataType=odt.name,
                     numInputVectors=list(ll_in_shape[:-1]),
                 )
diff --git a/src/finn/transformation/fpgadataflow/prepare_cppsim.py b/src/finn/transformation/fpgadataflow/prepare_cppsim.py
index 4f050be8540ddf5ef48699d1658b571852ff4510..6eae560e1191642cfaf85d92c6d0fcf644630973 100644
--- a/src/finn/transformation/fpgadataflow/prepare_cppsim.py
+++ b/src/finn/transformation/fpgadataflow/prepare_cppsim.py
@@ -80,7 +80,6 @@ class PrepareCppSim(Transformation):
             self._num_workers = mp.cpu_count()
 
     def prepareCppSim_node(self, node):
-        print(node.name)
         if is_fpgadataflow_node(node) is True:
             _codegen_single_node(node, self.model)
         return (node, False)
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index dbcf97361017144174f9fbfca35a84361b5abd26..4266488c7d1b86f2997d4c77d70b80f88bf37442 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -28,11 +28,13 @@
 
 import numpy as np
 from onnx import helper as oh
+import warnings
 
 from finn.core.datatype import DataType
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
+from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 
 
@@ -290,3 +292,38 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class AbsorbScalarMulIntoTopK(Transformation):
+    """Absorb a mul node into a suceeding topk node if the mul is scalar."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "TopK":
+                prod = model.find_producer(n.input[0])
+                if prod is not None and prod.op_type == "Mul":
+                    prod_input = prod.input[0]
+                    param_name = prod.input[1]
+                    A = model.get_initializer(param_name)
+                    if A is None:
+                        warnings.warn("Param is not constant, skipping")
+                        continue
+                    if all(x == 1 for x in A.shape) and A > 0:
+                        # if the mul is scalar and positive, we can just delete the
+                        # mul node and rewire the top k node. Because the top k node
+                        # works with probabilities and their relation to each other
+                        # the relation doesn't change if every value is multiplied
+                        # with a scalar
+                        graph.node.remove(prod)
+                        n.input[0] = prod_input
+                        # to avoid error the dataype is set to float32
+                        model.set_tensor_datatype(n.input[0], DataType.FLOAT32)
+                        graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index b46b82c77a3f1b70a3b05d87cd3c48fc1d94fd45..a1bd16f6d0b70193122d5d067ccdee395260c7b1 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -32,6 +32,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.datatype import DataType
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -338,6 +339,71 @@ class MoveScalarMulPastConv(Transformation):
         return (model, graph_modified)
 
 
+class MoveMulPastDWConv(Transformation):
+    """Move channelwise mul operations past depthwise conv operations. We want to have muls
+    next to each other such that they can be collapsed into a single mul."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Mul"
+                and not model.is_fork_node(n)
+                and not model.is_join_node(n)
+            ):
+                consumer = model.find_consumer(n.output[0])
+                if (
+                    consumer is not None
+                    and consumer.op_type == "Conv"
+                    and not model.is_join_node(consumer)
+                ):
+                    mul_weight_name = n.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    if A is None:
+                        warnings.warn(
+                            """Mul weight tensor is not set. If it is a constant,
+                                please use set_initializer to set the tensor."""
+                        )
+                        continue
+                    conv_node = consumer
+                    mul_node = n
+                    start_name = mul_node.input[0]
+                    conv_in_name = conv_node.input[0]
+                    conv_in_shape = model.get_tensor_shape(conv_in_name)
+                    ifm_ch = conv_in_shape[1]
+                    group_attribute = get_by_name(consumer.attribute, "group")
+                    if group_attribute is None:
+                        continue
+                    group_attribute = group_attribute.i
+                    conv_out_name = conv_node.output[0]
+                    conv_out_shape = model.get_tensor_shape(conv_out_name)
+                    if A.shape == (1, ifm_ch, 1, 1) and ifm_ch == group_attribute:
+                        # if the mul is channelwise and conv is depthwise,
+                        # we can simply swap the order of ops
+                        # rewire mul input to be conv input
+                        conv_node.input[0] = start_name
+                        model.set_tensor_shape(start_name, conv_in_shape)
+                        model.set_tensor_datatype(start_name, DataType.FLOAT32)
+                        # use old conv input tensor as conv output
+                        conv_node.output[0] = conv_in_name
+                        model.set_tensor_shape(conv_in_name, conv_out_shape)
+                        model.set_tensor_datatype(conv_in_name, DataType.FLOAT32)
+                        # use new conv output as new mul node input
+                        mul_node.input[0] = conv_in_name
+                        # use old conv output as new mul node output
+                        mul_node.output[0] = conv_out_name
+                        model.set_tensor_datatype(conv_out_name, DataType.FLOAT32)
+                        # move mul node past conv node
+                        graph.node.remove(mul_node)
+                        graph.node.insert(node_ind, mul_node)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
 class MoveLinearPastEltwiseAdd(Transformation):
     """Move linear operations (mul, add) past elementwise add operations where possible.
        Specifically,matches and transforms the following patterns:
diff --git a/src/finn/util/onnx.py b/src/finn/util/onnx.py
index 6a56f0cdcc85cd81a0c448971ab625268e8408d2..4d7cdd126ededac887639a932c2021ef5f081c02 100644
--- a/src/finn/util/onnx.py
+++ b/src/finn/util/onnx.py
@@ -41,9 +41,10 @@ def valueinfo_to_tensor(vi):
 
 
 def nchw_to_nhwc(t, model, idx, reverse=False):
-    """Converts a NCHW <-> NHWC by inserting a transpose. Input t is assumed NCHW.
-    By default we insert a transpose NCHW -> NHWC, but if reverse is true,
-    we convert NHWC -> NCHW"""
+    """Converts between NCHW <-> NHWC layouts for tensor t by inserting a transpose. 
+    If reverse=False, t is assumed NCHW and we insert transpose to convert NCHW -> NHWC
+    If reverse=True, t is assumed NHWC and we insert transpose to convert NHWC -> NCHW.
+    """
     graph = model.graph
     # create new NHWC tensor
     t_shape = model.get_tensor_shape(t)
diff --git a/tests/fpgadataflow/test_convert_to_hls_channelwise_layer.py b/tests/fpgadataflow/test_convert_to_hls_channelwise_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d09c64a1250f78604c1a0a362cf234712de2cf57
--- /dev/null
+++ b/tests/fpgadataflow/test_convert_to_hls_channelwise_layer.py
@@ -0,0 +1,115 @@
+import pytest
+
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
+    ReplaceVerilogRelPaths,
+)
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.infer_shapes import InferShapes
+import numpy as np
+
+
+def prepare_inputs(input_tensor):
+    return {"inp": input_tensor}
+
+
+def make_single_maxpool_modelwrapper(onnx_op_name, ishape, idt, pdt, pshape):
+
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, ishape)
+    p0 = helper.make_tensor_value_info("p0", TensorProto.FLOAT, pshape)
+
+    model = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[inp],
+            outputs=[outp],
+            value_info=[p0],
+            nodes=[helper.make_node(onnx_op_name, ["inp", "p0"], ["outp"])],
+        )
+    )
+
+    model = ModelWrapper(model)
+    model.set_initializer("p0", gen_finn_dt_tensor(pdt, pshape))
+    model.set_tensor_datatype("inp", idt)
+    model.transform(InferDataLayouts(), make_deepcopy=False)
+    model.transform(InferShapes(), make_deepcopy=False)
+    return model
+
+
+# parameter datatype
+@pytest.mark.parametrize("pdt", [DataType.BIPOLAR, DataType.UINT4, DataType.INT2])
+# input datatype
+@pytest.mark.parametrize("idt", [DataType.INT32, DataType.UINT4, DataType.INT4])
+# function
+@pytest.mark.parametrize("onnx_op_name", ["Add", "Mul"])
+# vector parameter or scalar parameter (broadcast)
+@pytest.mark.parametrize("scalar_param", [True, False])
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+@pytest.mark.vivado
+@pytest.mark.slow
+def test_convert_to_hls_channelwise_layer(
+    pdt, idt, onnx_op_name, scalar_param, exec_mode
+):
+    ifm_ch = 16
+    ifm_dim = 5
+    ishape = (1, ifm_ch, ifm_dim, ifm_dim)
+    if scalar_param:
+        pshape = (1,)
+    else:
+        pshape = (1, ifm_ch, 1, 1)
+
+    np.random.seed(0)
+    model = make_single_maxpool_modelwrapper(onnx_op_name, ishape, idt, pdt, pshape)
+
+    # Since the aren't Data types with a bit width of a non power of 2,
+    # there are cases where the input won't use it full range.
+    if idt == DataType.INT32:
+        x = gen_finn_dt_tensor(DataType.INT16, (1, ifm_ch, ifm_dim, ifm_dim))
+    elif idt == DataType.UINT32:
+        x = gen_finn_dt_tensor(DataType.UINT16, (1, ifm_ch, ifm_dim, ifm_dim))
+    else:
+        x = gen_finn_dt_tensor(idt, (1, ifm_ch, ifm_dim, ifm_dim))
+
+    input_dict = prepare_inputs(x)
+    y_expected = oxe.execute_onnx(model, input_dict)["outp"]
+
+    new_model = model.transform(to_hls.InferChannelwiseLinearLayer())
+    new_model = new_model.transform(GiveUniqueNodeNames())
+
+    if exec_mode == "cppsim":
+        new_model = new_model.transform(PrepareCppSim())
+        new_model = new_model.transform(CompileCppSim())
+        new_model = new_model.transform(SetExecMode("cppsim"))
+    elif exec_mode == "rtlsim":
+        new_model = new_model.transform(SetExecMode("rtlsim"))
+        new_model = new_model.transform(GiveUniqueNodeNames())
+        new_model = new_model.transform(PrepareIP("xc7z020clg400-1", 5))
+        new_model = new_model.transform(HLSSynthIP())
+        new_model = new_model.transform(ReplaceVerilogRelPaths())
+        new_model = new_model.transform(PrepareRTLSim())
+    else:
+        raise Exception("Unknown exec_mode")
+
+    ctx_produced = oxe.execute_onnx(
+        new_model, input_dict, return_full_exec_context=True
+    )
+    y_produced = ctx_produced["outp"]
+
+    assert (y_produced == y_expected).all()
+    assert new_model.graph.node[1].op_type == "ChannelwiseOp_Batch"
diff --git a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f78dcea1a1ce364d0657ad64de7d440d41b822
--- /dev/null
+++ b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from onnx import TensorProto, helper
+import numpy as np
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.custom_op.registry import getCustomOp
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.infer_shapes import InferShapes
+
+
+def make_single_maxpool_modelwrapper(k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt):
+    odt = idt
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ifm_ch, ofm_dim, ofm_dim]
+    )
+
+    mp_node = helper.make_node(
+        "MaxPool",
+        ["inp"],
+        ["outp"],
+        kernel_shape=[k, k],
+        pads=[pad, pad, pad, pad],
+        strides=[stride, stride],
+    )
+    graph = helper.make_graph(
+        nodes=[mp_node], name="mp_graph", inputs=[inp], outputs=[outp]
+    )
+
+    model = helper.make_model(graph, producer_name="mp-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+    model = model.transform(InferShapes())
+
+    return model
+
+
+def prepare_inputs(input_tensor):
+    return {"inp": input_tensor}
+
+
+# input datatype
+@pytest.mark.parametrize("idt", [DataType.UINT4, DataType.INT4])
+# pool configuration:                   ( k,stride, pad, ifm_dim )
+@pytest.mark.parametrize(
+    "pool_config", [(3, 2, 0, 5), (3, 2, 1, 5), (2, 2, 0, 8), (5, 2, 2, 7)]
+)
+# input channels
+@pytest.mark.parametrize("ifm_ch", [1, 4, 20])
+# number of out channel computed in parallel
+@pytest.mark.parametrize("pe", [1, 4, 20])
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+# pool type
+@pytest.mark.parametrize("op_type", ["MaxPool"])
+@pytest.mark.slow
+@pytest.mark.vivado
+def test_convert_to_hls_pool_batch(idt, pool_config, ifm_ch, pe, exec_mode, op_type):
+    k, stride, pad, ifm_dim = pool_config
+
+    if ifm_ch % pe != 0:
+        pytest.skip("ifm_ch%pe != 0. Skipping")
+
+    if pad != 0 and idt.signed():
+        pytest.skip("No support for pal_val != 0. Skipping")
+
+    np.random.seed(0)
+    ofm_dim = int(((ifm_dim + 2 * pad - k) / stride) + 1)
+
+    x = gen_finn_dt_tensor(idt, (1, ifm_ch, ifm_dim, ifm_dim))
+    # prepare input data
+    input_dict = prepare_inputs(x)
+    if op_type == "MaxPool":
+        model = make_single_maxpool_modelwrapper(
+            k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt
+        )
+    else:
+        assert False, "{} is not a supported op_type".format(op_type)
+
+    y_expected = oxe.execute_onnx(model, input_dict)["outp"]
+
+    new_model = model.transform(to_hls.InferPool_Batch())
+    new_model = new_model.transform(GiveUniqueNodeNames())
+
+    if ifm_ch != pe:
+        new_model = new_model.transform(to_hls.InferConvInpGen())
+        # Folding
+        for n in new_model.graph.node:
+            if n.op_type == "ConvolutionInputGenerator":
+                inst = getCustomOp(n)
+                inst.set_nodeattr("SIMD", pe)
+            elif n.op_type == "Pool_Batch":
+                inst = getCustomOp(n)
+                inst.set_nodeattr("PE", pe)
+
+    if exec_mode == "cppsim":
+        new_model = new_model.transform(SetExecMode("cppsim"))
+        new_model = new_model.transform(PrepareCppSim())
+        new_model = new_model.transform(CompileCppSim())
+    elif exec_mode == "rtlsim":
+        new_model = new_model.transform(SetExecMode("rtlsim"))
+        new_model = new_model.transform(GiveUniqueNodeNames())
+        new_model = new_model.transform(PrepareIP("xc7z020clg400-1", 5))
+        new_model = new_model.transform(HLSSynthIP())
+        new_model = new_model.transform(PrepareRTLSim())
+    else:
+        raise Exception("Unknown exec_mode")
+
+    # execute new_model
+    y_produced = oxe.execute_onnx(new_model, input_dict)["outp"]
+    assert (y_produced == y_expected).all()
+    if stride != k:
+        if pad == 0 or ifm_ch == pe:
+            assert len(new_model.graph.node) == 4
+        else:
+            assert len(new_model.graph.node) == 5
+    else:
+        assert len(new_model.graph.node) == 1
diff --git a/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py b/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py
index 05d8e28498316f0a76da83cd70611801fdb37846..6a69d8180a4a9825a83b419666bbccb9f203a7d8 100644
--- a/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py
+++ b/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py
@@ -89,7 +89,7 @@ def make_modelwrapper(C, pe, idt, odt, func, vecs):
 # input datatype
 @pytest.mark.parametrize("idt", [DataType.INT4])
 # folding, -1 is maximum possible
-@pytest.mark.parametrize("nf", [-1, 2, 1])
+@pytest.mark.parametrize("nf", [-1, 2])
 # number of input features
 @pytest.mark.parametrize("ich", [16])
 # vecs
diff --git a/tests/transformation/test_absorb_mul_into_topk.py b/tests/transformation/test_absorb_mul_into_topk.py
new file mode 100644
index 0000000000000000000000000000000000000000..1394220f7c336ccea8fe9c494734c4175bf2e847
--- /dev/null
+++ b/tests/transformation/test_absorb_mul_into_topk.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
+from finn.transformation.insert_topk import InsertTopK
+from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK
+import finn.core.onnx_exec as oxe
+
+# parameter to indicate if mul parameter is negative or positive
+@pytest.mark.parametrize("mul_positive", [True, False])
+# parameter to indicate if mul parameter is scalar or not
+@pytest.mark.parametrize("scalar", [True, False])
+def test_absorb_mul_into_topk(mul_positive, scalar):
+    if scalar is True:
+        shape = [1]
+    else:
+        shape = [1, 1, 1, 1000]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000])
+    a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000])
+
+    mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"])
+    mul_graph = helper.make_graph(
+        nodes=[mul_node],
+        name="mul-graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[a0],
+    )
+
+    model = helper.make_model(mul_graph, producer_name="mul_model")
+    model = ModelWrapper(model)
+    # initialize values
+    if mul_positive is True:
+        a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype(
+            np.float32
+        )
+    else:
+        a0_values = np.random.uniform(low=-1, high=-0.1, size=tuple(shape)).astype(
+            np.float32
+        )
+    model.set_initializer("a0", a0_values)
+    model = model.transform(InsertTopK())
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    model_transformed = model.transform(AbsorbScalarMulIntoTopK())
+
+    # compare execution results
+    inp_values = np.random.uniform(low=-10, high=10, size=(1, 1, 1, 1000)).astype(
+        np.float32
+    )
+    idict = {"global_in": inp_values}
+    odict = oxe.execute_onnx(model, idict, True)
+    y_indices = odict["global_out"]
+    y_values = odict["TopK_0_out0"]
+    odict = oxe.execute_onnx(model_transformed, idict, True)
+    y_tr_indices = odict["global_out"]
+    y_tr_values = odict["TopK_0_out0"]
+
+    # the indices stay the same, if the model is transformed or not
+    assert (y_indices == y_tr_indices).all()
+
+    if scalar is True and mul_positive is True:
+        # the values change if the model was transformed
+        assert (y_values != y_tr_values).all()
+
+        # check for new order
+        assert model.graph != model_transformed.graph
+        assert len(model.graph.node) - 1 == len(model_transformed.graph.node)
+        assert model_transformed.graph.node[0].op_type == "TopK"
+
+    else:
+        assert (y_values == y_tr_values).all()
+        assert model.graph == model_transformed.graph
diff --git a/tests/transformation/test_move_mul_past_dw_conv.py b/tests/transformation/test_move_mul_past_dw_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ae8fbfe89986d58d3d71f5f8735a98469d9d1e3
--- /dev/null
+++ b/tests/transformation/test_move_mul_past_dw_conv.py
@@ -0,0 +1,93 @@
+import pytest
+
+from onnx import helper, TensorProto
+from finn.custom_op.im2col import compute_conv_output_dim
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.streamline.reorder import MoveMulPastDWConv
+
+
+# input dimension
+@pytest.mark.parametrize("ifm_dim", [4, 7])
+# input channels
+@pytest.mark.parametrize("ifm_ch", [2, 3])
+# kernel size
+@pytest.mark.parametrize("k", [2, 3])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# padding
+@pytest.mark.parametrize("pad_amt", [0, 1])
+# depthwise
+@pytest.mark.parametrize("dw", [0, 1])
+def test_move_mul_past_dw_conv(ifm_dim, ifm_ch, k, stride, pad_amt, dw):
+    if dw == 1:
+        ofm_ch = ifm_ch
+        groups = ifm_ch
+        W_shape = [ofm_ch, 1, k, k]
+    else:
+        ofm_ch = ifm_ch + 2
+        groups = 1
+        W_shape = [ofm_ch, ifm_ch, k, k]
+
+    ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad_amt)
+
+    # set up onnx model
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [1, ifm_ch, 1, 1])
+    W = helper.make_tensor_value_info("W", TensorProto.FLOAT, W_shape)
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
+    )
+
+    Mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
+
+    Conv_node = helper.make_node(
+        "Conv",
+        ["mul_out", "W"],
+        ["outp"],
+        group=groups,
+        kernel_shape=[k, k],
+        pads=[pad_amt, pad_amt, pad_amt, pad_amt],
+        strides=[stride, stride],
+    )
+
+    graph = helper.make_graph(
+        nodes=[Mul_node, Conv_node],
+        name="mulpastconv_graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[mul, W],
+    )
+
+    model = helper.make_model(graph, producer_name="mulpastconv-model")
+    model = ModelWrapper(model)
+    inp_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, ifm_dim, ifm_dim])
+    mul_values = gen_finn_dt_tensor(DataType.INT2, [1, ifm_ch, 1, 1])
+    W_values = gen_finn_dt_tensor(DataType.INT2, W_shape)
+    model.set_initializer("W", W_values)
+    model.set_initializer("mul", mul_values)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    idict = {"inp": inp_values}
+    odict = oxe.execute_onnx(model, idict, True)
+    out_before = odict["outp"]
+
+    # move channelwise multiplication past depthwise conv
+    model_transformed = model.transform(MoveMulPastDWConv())
+    odict = oxe.execute_onnx(model_transformed, idict, True)
+    out_after = odict["outp"]
+
+    assert (out_before == out_after).all()
+
+    if dw == 0:
+        assert model.graph.node[0].op_type == model_transformed.graph.node[0].op_type
+        assert model.graph.node[1].op_type == model_transformed.graph.node[1].op_type
+    else:
+        assert model.graph.node[0].op_type == model_transformed.graph.node[1].op_type
+        assert model.graph.node[1].op_type == model_transformed.graph.node[0].op_type