diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 738bfa25403ded4bf22945e1dcd353ae9d5634fc..72aa322e0e44a6f4a5c11025d94bdfeb820338a3 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -617,6 +617,9 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                 wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode)
                 if inp_is_bipolar and wt_is_bipolar:
                     tdt = DataType.UINT32
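+                # ensure all thresholds are representable in the chosen datatype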
+                assert np.vectorize(tdt.allowed)(
+                    threshold_tensor
+                ).all(), "Thresholds can't be expressed with type %s" % str(tdt)
                 thresholds_hls_code = numpy_to_hls_code(
                     threshold_tensor, tdt, "thresholds", False, True
                 )
diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
index c2e3739e8f62b5ce0459ee8fbb1f3dcda7b50c1e..379ebd92d86d54c6bc621c7f89b01eacba2b5d3f 100644
--- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -284,6 +284,9 @@ class Thresholding_Batch(HLSCustomOp):
 
         threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
         tdt = DataType.INT32
+        assert np.vectorize(tdt.allowed)(
+            threshold_tensor
+        ).all(), "Thresholds can't be expressed with type %s" % str(tdt)
         thresholds_hls_code = numpy_to_hls_code(
             threshold_tensor, tdt, "thresholds", False, True
         )
diff --git a/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py b/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..942e4b25700d0c52c1bc5bcd81614a058342f178
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
@@ -0,0 +1,506 @@
+import os
+import numpy as np
+
+from onnx import TensorProto, helper
+from finn.core.datatype import DataType
+from finn.custom_op.fpgadataflow import HLSCustomOp
+from finn.util.basic import interleave_matrix_outer_dim_from_partitions
+from finn.util.data_packing import (
+    npy_to_rtlsim_input,
+    numpy_to_hls_code,
+    rtlsim_output_to_npy,
+)
+
+
+class Vector_Vector_Activate_Batch(HLSCustomOp):
+    """Class that corresponds to finn-hlslib Vector_Vector_Activate_Batch function"""
+
+    def __init__(self, onnx_node):
+        super().__init__(onnx_node)
+
+    def get_nodeattr_types(self):
+        my_attrs = {
+            "PE": ("i", True, 0),
+            "Dim": ("i", True, 0),
+            "Channels": ("i", True, 0),
+            "Kernel": ("i", True, 0),
+            "resType": ("s", True, ""),
+            "ActVal": ("i", False, 0),
+            # FINN DataTypes for inputs, weights, outputs
+            "inputDataType": ("s", True, ""),
+            "weightDataType": ("s", True, ""),
+            "outputDataType": ("s", True, ""),
+            # no-activation mode (produce accumulators)
+            "noActivation": ("i", False, 0),
+        }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
+
+    def calc_wmem(self):
+        """Calculates and returns WMEM."""
+        ch = self.get_nodeattr("Channels")
+        k = self.get_nodeattr("Kernel")
+        pe = self.get_nodeattr("PE")
+        wmem = k * k * ch // pe
+        return wmem
+
+    def calc_tmem(self):
+        """Calculates and returns TMEM."""
+        if self.get_nodeattr("noActivation") == 1:
+            return 0
+        else:
+            ch = self.get_nodeattr("Channels")
+            pe = self.get_nodeattr("PE")
+            return ch // pe
+
+    def make_shape_compatible_op(self, model):
+        oshape = self.get_normal_output_shape()
+        # implement tensor with correct shape
+        values = np.random.randn(*oshape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # check input datatype against property
+        idt_name = self.get_input_datatype().name
+        exp_idt_name = self.get_nodeattr("inputDataType")
+        assert exp_idt_name == idt_name, "Bad input DataType for VVAU node"
+        # set output datatype from property
+        odt = self.get_output_datatype()
+        model.set_tensor_datatype(node.output[0], odt)
+
+    def verify_node(self):
+        pass
+
+    def get_input_datatype(self):
+        """Returns FINN DataType of input."""
+        return DataType[self.get_nodeattr("inputDataType")]
+
+    def get_weight_datatype(self):
+        """Returns FINN DataType of weights."""
+        return DataType[self.get_nodeattr("weightDataType")]
+
+    def get_output_datatype(self):
+        """Returns FINN DataType of output."""
+        return DataType[self.get_nodeattr("outputDataType")]
+
+    def get_instream_width(self):
+        i_bits = self.get_input_datatype().bitwidth()
+        in_width = i_bits * self.get_nodeattr("PE")
+        return in_width
+
+    def get_outstream_width(self):
+        o_bits = self.get_output_datatype().bitwidth()
+        out_width = o_bits * self.get_nodeattr("PE")
+        return out_width
+
+    def get_folded_input_shape(self):
+        k = self.get_nodeattr("Kernel")
+        sf = k * k
+        dim = self.get_nodeattr("Dim")
+        ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        nf = ch // pe
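+        # each fold carries pe channel values; sf * nf = (k*k) * (ch/pe)
+        # folds make up one output pixel's worth of input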
+        folded_input_shape = tuple([1, dim, dim, sf * nf, pe])
+        return folded_input_shape
+
+    def get_folded_output_shape(self):
+        ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        nf = ch // pe
+        dim = self.get_nodeattr("Dim")
+        folded_output_shape = tuple([1, dim, dim, nf, pe])
+        return folded_output_shape
+
+    def get_normal_input_shape(self):
+        dim = self.get_nodeattr("Dim")
+        ch = self.get_nodeattr("Channels")
+        k = self.get_nodeattr("Kernel")
+        normal_input_shape = tuple([1, dim, dim, k * k * ch])
+        return normal_input_shape
+
+    def get_normal_output_shape(self):
+        ch = self.get_nodeattr("Channels")
+        dim = self.get_nodeattr("Dim")
+        normal_output_shape = tuple([1, dim, dim, ch])
+        return normal_output_shape
+
+    def get_number_output_values(self):
+        nf = np.prod(self.get_folded_output_shape()[:-1])
+        return nf
+
+    def get_exp_cycles(self):
+        pe = self.get_nodeattr("PE")
+        ch = self.get_nodeattr("Channels")
+        dim = self.get_nodeattr("Dim")
+        k = self.get_nodeattr("Kernel")
+        # FINN currently supports only a batch size of 1 for the VVAU
+        batch_size = 1
+        # mmv != 1 is not supported yet, so mmv is fixed to 1
+        mmv = 1
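+        # example: ch=16, k=3, pe=4, dim=7 -> (16 * 9 / 4) * 49 = 1764 cycles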
+        exp_cycles = ((ch * k * k) / pe) * batch_size * (dim * dim) / mmv
+        return int(exp_cycles)
+
+    def get_template_param_values(self):
+        """Returns the template parameter values according to input, output and weight
+        data types."""
+        ret = dict()
+        inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
+        out_hls_str = self.get_output_datatype().get_hls_datatype_str()
+        inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
+        wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
+        # fill in TSrcI and TWeightI
+        # TODO handle bipolar inputs
+        if inp_is_bipolar or wt_is_bipolar:
+            raise Exception("VVAU node doesn't support bipolar values yet.")
+        else:
+            ret["TSrcI"] = "Slice<%s>" % inp_hls_str
+            ret["TWeightI"] = "Identity"
+
+        # fill in TDstI
+        ret["TDstI"] = "Slice<%s>" % out_hls_str
+
+        return ret
+
+    def get_hls_compatible_weight_tensor(self, orig_weight_matrix):
+        pe = self.get_nodeattr("PE")
+        ch = self.get_nodeattr("Channels")
+        k = self.get_nodeattr("Kernel")
+        wmem = self.calc_wmem()
+        assert orig_weight_matrix.shape == (
+            ch,
+            1,
+            k,
+            k,
+        ), """Weights matrix doesn't
+        have expected shape (channels, 1, kernel_size, kernel_size)"""
+        ret = orig_weight_matrix
+        ret = ret.reshape(ch, k * k)
+        # distribute rows between PEs
+        ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
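+        # e.g. ch=4, pe=2: PE0 receives rows 0 and 2, PE1 rows 1 and 3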
+        ret = ret.reshape(1, pe, wmem, 1)
+        return ret
+
+    def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
+        ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        tmem = self.calc_tmem()
+        assert ch % pe == 0, "Requirement Channels divisible by PE is violated."
+        assert (
+            orig_thres_matrix.ndim == 2
+        ), """Threshold matrix dimension is
+        not as expected (2)."""
+        n_thres_steps = orig_thres_matrix.shape[1]
+        ret = orig_thres_matrix
+        # distribute rows between PEs
+        ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
+        assert (
+            ret.shape[0] == pe
+        ), """First dimension after distribution of the
+        rows between PEs is not as expected (pe)"""
+        assert (
+            ret.shape[1] == tmem
+        ), """Second dimension after distribution of the
+        rows between PEs is not as expected (tmem)"""
+        assert (
+            ret.shape[2] == n_thres_steps
+        ), """Third dimension after distribution of the
+        rows between PEs is not as expected (n_thres_steps)"""
+        return ret.reshape(1, pe, tmem, n_thres_steps)
+
+    def generate_params(self, model, path):
+        # weights
+        weights = model.get_initializer(self.onnx_node.input[1])
+        # convert weights into hlslib-compatible format
+        weight_tensor = self.get_hls_compatible_weight_tensor(weights)
+        wdt = self.get_weight_datatype()
+        code_gen_dir = path
+
+        """Saves weights into params.h"""
+        weight_hls_code = numpy_to_hls_code(weight_tensor, wdt, "weights", True, True)
+        # write weights into params.h
+        f_weights = open("{}/params.h".format(code_gen_dir), "w")
+
+        if wdt.bitwidth() != 1:
+            f_weights.write(
+                "const FixedPointWeights<1,{},{},{}> weights = ".format(
+                    wdt.get_hls_datatype_str(),
+                    self.get_nodeattr("PE"),
+                    self.calc_wmem(),
+                )
+            )
+        else:
+            f_weights.write(
+                "const BinaryWeights<1,{},{}> weights = ".format(
+                    self.get_nodeattr("PE"), self.calc_wmem()
+                )
+            )
+        f_weights.write(weight_hls_code)
+        f_weights.close()
+
+        # save thresholds in thresh.h
+        if len(self.onnx_node.input) > 2:
+            thresholds = model.get_initializer(self.onnx_node.input[2])
+            if thresholds is not None:
+                threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
+                tdt = DataType.INT32
+                assert np.vectorize(tdt.allowed)(
+                    threshold_tensor
+                ).all(), "Thresholds can't be expressed with type %s" % str(tdt)
+                thresholds_hls_code = numpy_to_hls_code(
+                    threshold_tensor, tdt, "thresholds", False, True
+                )
+                # write thresholds into thresh.h
+                f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
+                tdt_hls = tdt.get_hls_datatype_str()
+                odt = self.get_output_datatype()
+                odt_hls = odt.get_hls_datatype_str()
+                f_thresh.write(
+                    "static ThresholdsActivation<{},{},{},{},{},{},{}> threshs \
+                    = ".format(
+                        self.calc_tmem(),
+                        self.get_nodeattr("PE"),
+                        threshold_tensor.shape[-1],
+                        tdt_hls,
+                        odt_hls,
+                        self.get_nodeattr("ActVal"),
+                        "std::less_equal<%s>" % tdt_hls,
+                    )
+                )
+                f_thresh.write(thresholds_hls_code)
+                f_thresh.close()
+
+    def execute_node(self, context, graph):
+        mode = self.get_nodeattr("exec_mode")
+        node = self.onnx_node
+
+        # TODO ensure codegen dir exists
+        if mode == "cppsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        elif mode == "rtlsim":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+        # create a .npy file for each input of the node (in_ind is the input index)
+        in_ind = 0
+        for inputs in node.input:
+            # it is assumed that the first input of the node is the data input,
+            # the second input is the weights and the third the thresholds
+            if in_ind == 0:
+                assert (
+                    str(context[inputs].dtype) == "float32"
+                ), """Input datatype is
+                not float32 as expected."""
+                expected_inp_shape = self.get_folded_input_shape()
+                reshaped_input = context[inputs].reshape(expected_inp_shape)
+                # make copy before saving the array
+                reshaped_input = reshaped_input.copy()
+                np.save(
+                    os.path.join(code_gen_dir, "input_{}.npy".format(in_ind)),
+                    reshaped_input,
+                )
+            elif in_ind > 2:
+                raise Exception(
+                    "Unexpected input found for Vector_Vector_Activate_Unit"
+                )
+            in_ind += 1
+
+        if mode == "cppsim":
+            # execute the precompiled model
+            super().exec_precompiled_singlenode_model()
+            # load output npy file
+            super().npy_to_dynamic_output(context)
+            assert (
+                context[node.output[0]].shape == self.get_folded_output_shape()
+            ), """Output shape is not as expected"""
+            # reshape output to have expected shape
+            oshape = self.get_normal_output_shape()
+            context[node.output[0]] = context[node.output[0]].reshape(*oshape)
+        elif mode == "rtlsim":
+            sim = self.get_rtlsim()
+            nbits = self.get_instream_width()
+            idt = self.get_input_datatype()
+            inp = npy_to_rtlsim_input("{}/input_0.npy".format(code_gen_dir), idt, nbits)
+            super().reset_rtlsim(sim)
+            super().toggle_clk(sim)
+            output = self.rtlsim(sim, inp)
+            odt = self.get_output_datatype()
+            target_bits = odt.bitwidth()
+            packed_bits = self.get_outstream_width()
+            out_npy_path = "{}/output.npy".format(code_gen_dir)
+            out_shape = self.get_folded_output_shape()
+            rtlsim_output_to_npy(
+                output, out_npy_path, odt, out_shape, packed_bits, target_bits
+            )
+
+            # load and reshape output
+            output = np.load(out_npy_path)
+            oshape = self.get_normal_output_shape()
+            output = np.asarray([output], dtype=np.float32).reshape(*oshape)
+            context[node.output[0]] = output
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following value ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
+
+    def global_includes(self):
+        self.code_gen_dict["$GLOBALS$"] = ['#include "weights.hpp"']
+        self.code_gen_dict["$GLOBALS$"] += ['#include "activations.hpp"']
+        if self.calc_tmem() != 0:
+            self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
+
+    def defines(self, var):
+        dim = self.get_nodeattr("Dim")
+        numReps = 1 * dim * dim
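+        # numReps = number of output pixels per image (batch size fixed to 1)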
+        self.code_gen_dict["$DEFINES$"] = [
+            """#define Channels1 {}\n #define Kernel1 {}\n
+            #define SIMD1 1\n #define PE1 {}\n #define numReps {}""".format(
+                self.get_nodeattr("Channels"),
+                self.get_nodeattr("Kernel"),
+                self.get_nodeattr("PE"),
+                numReps,
+            )
+        ]
+
+    def read_npy_data(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_input_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_instream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_in = "%s/input_0.npy" % code_gen_dir
+        self.code_gen_dict["$READNPYDATA$"] = []
+        # note: the innermost dim is reversed for the input
+        self.code_gen_dict["$READNPYDATA$"].append(
+            'npy2apintstream<%s, %s, %d, %s>("%s", in0, false);'
+            % (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
+        )
+
+    def strm_decl(self):
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = []
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+        )
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+            'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
+        )
+
+    def docompute(self):
+        tmpl_args = self.get_template_param_values()
+        if self.calc_tmem() == 0:
+            odtype_hls_str = self.get_output_datatype().get_hls_datatype_str()
+            threshs = "PassThroughActivation<%s>()" % odtype_hls_str
+        else:
+            threshs = "threshs"
+        node = self.onnx_node
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            """{}<Channels1, Kernel1, SIMD1, PE1, 1, {}, {}, {}>
+            (in0, out, weights, {}, numReps, {});""".format(
+                node.op_type,
+                tmpl_args["TSrcI"],
+                tmpl_args["TDstI"],
+                tmpl_args["TWeightI"],
+                threshs,
+                self.get_nodeattr("resType"),
+            )
+        ]
+
+    def dataoutstrm(self):
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+        dtype = self.get_output_datatype()
+        elem_bits = dtype.bitwidth()
+        packed_bits = self.get_outstream_width()
+        packed_hls_type = "ap_uint<%d>" % packed_bits
+        elem_hls_type = dtype.get_hls_datatype_str()
+        npy_type = "float"
+        npy_out = "%s/output.npy" % code_gen_dir
+        shape = self.get_folded_output_shape()
+        shape_cpp_str = str(shape).replace("(", "{").replace(")", "}")
+
+        # note: the innermost dim is not reversed for the output
+        self.code_gen_dict["$DATAOUTSTREAM$"] = [
+            'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s", false);'
+            % (
+                packed_hls_type,
+                elem_hls_type,
+                elem_bits,
+                npy_type,
+                shape_cpp_str,
+                npy_out,
+            )
+        ]
+
+    def save_as_npy(self):
+        self.code_gen_dict["$SAVEASCNPY$"] = []
+
+    def blackboxfunction(self):
+        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+            """void {}(hls::stream<ap_uint<{}>> &in0,
+            hls::stream<ap_uint<{}>> &out
+            )""".format(
+                self.onnx_node.name,
+                self.get_instream_width(),
+                self.get_outstream_width(),
+            )
+        ]
+
+    def pragmas(self):
+        self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
+        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out")
+        in_fifo_depth = self.get_nodeattr("inFIFODepth")
+        out_fifo_depth = self.get_nodeattr("outFIFODepth")
+        # insert depth pragmas only if specified
+        if in_fifo_depth != 0:
+            self.code_gen_dict["$PRAGMAS$"].append(
+                "#pragma HLS stream depth=%d variable=in0" % in_fifo_depth
+            )
+        if out_fifo_depth != 0:
+            self.code_gen_dict["$PRAGMAS$"].append(
+                "#pragma HLS stream depth=%d variable=out" % out_fifo_depth
+            )
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS INTERFACE ap_ctrl_none port=return"
+        )
+
+        self.code_gen_dict["$PRAGMAS$"].append('#include "params.h"')
+        # the weight tensor is ap_uint<ch*prec> [PE][WMEM]
+        # partition for parallel access along the PE dimension (dim 1)
+        self.code_gen_dict["$PRAGMAS$"].append(
+            ("#pragma HLS ARRAY_PARTITION variable=weights.m_weights " "complete dim=1")
+        )
+        if self.calc_tmem() != 0:
+            # TODO find a better way of checking for no pregenerated thresholds
+            self.code_gen_dict["$PRAGMAS$"].append(
+                (
+                    "#pragma HLS ARRAY_PARTITION variable=threshs.m_thresholds "
+                    "complete dim=1"
+                )
+            )
+            self.code_gen_dict["$PRAGMAS$"].append(
+                (
+                    "#pragma HLS ARRAY_PARTITION variable=threshs.m_thresholds "
+                    "complete dim=3"
+                )
+            )
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index e4317e02d46df90c8fd0c8854262ca6eb0ea4f48..0cc0e53eaebd94d5e2cd0e030bc107da098e4931 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -52,6 +52,9 @@ from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
 from finn.custom_op.quantavgpool2d import QuantAvgPool2d
 from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
+from finn.custom_op.fpgadataflow.vector_vector_activate_batch import (
+    Vector_Vector_Activate_Batch,
+)
 from finn.custom_op.fpgadataflow.channelwise_op_batch import ChannelwiseOp_Batch
 from finn.custom_op.fpgadataflow.iodma import IODMA
 
@@ -78,6 +81,7 @@ custom_op["AddStreams_Batch"] = AddStreams_Batch
 custom_op["LabelSelect_Batch"] = LabelSelect_Batch
 custom_op["QuantAvgPool2d"] = QuantAvgPool2d
 custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
+custom_op["Vector_Vector_Activate_Batch"] = Vector_Vector_Activate_Batch
 custom_op["ChannelwiseOp_Batch"] = ChannelwiseOp_Batch
 custom_op["IODMA"] = IODMA
 
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 4cdf138130f37809357b281155d260fdbd789e12..7b929edc4e672199a2eb6d7c8f427365af0dd9f5 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -38,7 +38,6 @@ from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.general import SortGraph
 import finn.core.data_layout as DataLayout
 from finn.util.onnx import nchw_to_nhwc
-import warnings
 from finn.util.basic import get_by_name
 
 
@@ -509,7 +508,7 @@ class InferQuantizedStreamingFCLayer(Transformation):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "MatMul":
+            if n.op_type == "MatMul" and model.get_tensor_sparsity(n.input[1]) is None:
                 mm_input = n.input[0]
                 mm_weight = n.input[1]
                 mm_output = n.output[0]
@@ -628,6 +627,150 @@ class InferQuantizedStreamingFCLayer(Transformation):
         return (model, graph_modified)
 
 
+class InferVVAU(Transformation):
+    """Convert MatMul layers with quantized inputs and weights to
+    Vector_Vector_Activate_Batch layers, if the sparsity annotation
+    of the weight matrix indicates that the MatMul layer belongs to
+    a depthwise convolution. Any immediately following MultiThreshold
+    layers will also be absorbed into the VVAU."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "MatMul"
+                and model.get_tensor_sparsity(n.input[1]) is not None
+            ):
+                sparsity = model.get_tensor_sparsity(n.input[1])
+                try:
+                    k = sparsity["dw"]["kernel_shape"]
+                except KeyError:
+                    raise Exception(
+                        """Sparsity doesn't indicate that MatMul
+                        belongs to a depthwise convolution."""
+                    )
+
+                mm_input = n.input[0]
+                mm_weight = n.input[1]
+                mm_output = n.output[0]
+                mm_in_shape = model.get_tensor_shape(mm_input)
+                mm_out_shape = model.get_tensor_shape(mm_output)
+                idt = model.get_tensor_datatype(mm_input)
+                wdt = model.get_tensor_datatype(mm_weight)
+                if idt.is_integer() and wdt.is_integer():
+                    mm_output = n.output[0]
+                    W = model.get_initializer(mm_weight)
+                    # infer the dense weight tensor from the sparse weight
+                    # matrix, using the kernel size k extracted above and the
+                    # number of channels.
+                    # the sparse weight matrix has shape (k * k * Channels,
+                    # Channels); reversing its construction yields a weight
+                    # tensor of shape (Channels, 1, k, k)
+                    channels = int(W.shape[1])
+                    # transpose to achieve a shape of (Channels, k * k * Channels)
+                    W = W.T
+                    # reshape to (Channels, k, k, Channels), then transpose
+                    # to (Channels, Channels, k, k)
+                    W = W.reshape(channels, k, k, channels)
+                    W = W.transpose(0, 3, 1, 2)
+                    # extract the per-channel kernels by looping over the
+                    # channels and filling a zero array of the target shape
+                    w_tensor = np.zeros((channels, 1, k, k))
+                    for ch in range(channels):
+                        w_tensor[ch][0] = W[ch][ch]
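+                    # e.g. channels=2: W[0][0] and W[1][1] hold the two k x k
+                    # kernels, all off-diagonal blocks are zero by construction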
+                    model.set_initializer(mm_weight, w_tensor)
+                    model.set_tensor_shape(mm_weight, (channels, 1, k, k))
+                    # create node with pe=channels as default
+                    pe = channels
+                    assert (
+                        channels % pe == 0
+                    ), "Requirement Channels divisable by PE is violated."
+                    # see if we have any following thresholds
+                    consumer = model.find_consumer(mm_output)
+                    if consumer is not None and consumer.op_type == "MultiThreshold":
+                        # create VVAU (i.e. including activation)
+                        mt_output = consumer.output[0]
+                        mt_out_shape = model.get_tensor_shape(mt_output)
+                        mt_thres = consumer.input[1]
+                        T = model.get_initializer(mt_thres)
+                        assert (
+                            T.shape[0] == 1 or T.shape[0] == channels
+                        ), """First dimension of
+                        thresholds neither 1 nor Channels."""
+                        odt = model.get_tensor_datatype(mt_output)
+                        scale = getCustomOp(consumer).get_nodeattr("out_scale")
+                        assert (
+                            scale == 1.0
+                        ), "out_scale must be equal to 1.0 for HLS conversion."
+                        actval = getCustomOp(consumer).get_nodeattr("out_bias")
+                        assert (
+                            int(actval) == actval
+                        ), "out_bias must be integer for HLS conversion."
+                        actval = int(actval)
+                        assert (not odt.signed()) or (
+                            actval < 0
+                        ), "Signed output requres actval < 0"
+                        model.set_tensor_shape(mm_input, mm_in_shape)
+                        model.set_tensor_shape(mt_output, mt_out_shape)
+                        # create and insert new Vector_Vector_Activate_Batch node
+                        new_node = helper.make_node(
+                            "Vector_Vector_Activate_Batch",
+                            [mm_input, mm_weight, mt_thres],
+                            [mt_output],
+                            domain="finn",
+                            backend="fpgadataflow",
+                            resType="ap_resource_lut()",
+                            PE=pe,
+                            Dim=mm_in_shape[1],
+                            Channels=channels,
+                            Kernel=k,
+                            inputDataType=idt.name,
+                            weightDataType=wdt.name,
+                            outputDataType=odt.name,
+                            ActVal=actval,
+                            noActivation=0,
+                        )
+                        graph.node.insert(node_ind, new_node)
+                        # remove old nodes
+                        graph.node.remove(n)
+                        graph.node.remove(consumer)
+                        graph_modified = True
+                    else:
+                        # no activation, matmul only
+                        odt = model.get_tensor_datatype(mm_output)
+                        model.set_tensor_shape(mm_input, mm_in_shape)
+                        model.set_tensor_shape(mm_output, mm_out_shape)
+                        # create and insert new VVAU node
+                        new_node = helper.make_node(
+                            "Vector_Vector_Activate_Batch",
+                            [mm_input, mm_weight],
+                            [mm_output],
+                            domain="finn",
+                            backend="fpgadataflow",
+                            resType="ap_resource_lut()",
+                            PE=pe,
+                            Dim=mm_in_shape[1],
+                            Channels=channels,
+                            Kernel=k,
+                            inputDataType=idt.name,
+                            weightDataType=wdt.name,
+                            outputDataType=odt.name,
+                            ActVal=0,
+                            noActivation=1,
+                        )
+                        graph.node.insert(node_ind, new_node)
+                        # remove old node
+                        graph.node.remove(n)
+                        graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class InferThresholdingLayer(Transformation):
     """Convert any MultiThreshold into a standalone thresholding HLS layer."""
 
diff --git a/src/finn/transformation/lower_convs_to_matmul.py b/src/finn/transformation/lower_convs_to_matmul.py
index aa231a43a3865a161a501b4997ff2f538800554f..e5a1f778d0cac48925ecd97ae8b970f7bdab9c4f 100644
--- a/src/finn/transformation/lower_convs_to_matmul.py
+++ b/src/finn/transformation/lower_convs_to_matmul.py
@@ -26,6 +26,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import numpy as np
 from onnx import TensorProto
 from onnx import helper
 
@@ -54,12 +55,34 @@ class LowerConvsToMatMul(Transformation):
                 k = get_by_name(n.attribute, "kernel_shape").ints[-1]
                 pad = get_by_name(n.attribute, "pads").ints[-1]
                 stride = get_by_name(n.attribute, "strides").ints[-1]
+                group = get_by_name(n.attribute, "group").i
                 weight_name = n.input[1]
                 W_conv = model.get_initializer(weight_name)
-                ifm_ch = W_conv.shape[1]
-                ofm_ch = W_conv.shape[0]
+                ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
+                ofm_ch = model.get_tensor_shape(n.output[0])[1]  # assume NCHW
                 ifm_dim = model.get_tensor_shape(n.input[0])[-1]  # assume NCHW
                 ofm_dim = model.get_tensor_shape(n.output[0])[-1]  # assume NCHW
+
+                # for a depthwise conv, create a sparse weight matrix and set
+                # the flag "dw", which is later stored as an attribute on the
+                # Im2Col node to mark it as part of a depthwise convolution
+                dw = False
+                if group == ifm_ch and ofm_ch == ifm_ch:
+                    W_sparse = np.zeros((ofm_ch, ifm_ch, k, k))
+                    for ch in range(ifm_ch):
+                        W_sparse[ch][ch] = W_conv[ch][0]
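+                    # e.g. ifm_ch=2: W_sparse = [[K0, 0], [0, K1]], where Ki is
+                    # the k x k kernel of channel i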
+                    W_conv = W_sparse.astype(np.float32)
+                    # record the sparsity pattern of the weight matrix via
+                    # the sparsity annotation of the weight tensor
+                    sparsity = {"dw": {"kernel_shape": k}}
+                    model.set_tensor_sparsity(weight_name, sparsity)
+                    # set "dw" so the Im2Col node created below is marked as
+                    # belonging to a depthwise convolution
+                    dw = True
+
                 # reuse conv weights for new matmul weights
                 # conv weights are [OFM][IFM][k][k]
                 # first convert to [OFM][k][k][IFM] (to remain compatible with
@@ -70,6 +93,7 @@ class LowerConvsToMatMul(Transformation):
                 # transpose to get ONNX-compatible [k*k*IFM][OFM] matrix
                 W_matmul = W_matmul.T
                 model.set_initializer(weight_name, W_matmul)
+
                 # create new intermediate values
                 inp_trans_out = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
@@ -121,6 +145,7 @@ class LowerConvsToMatMul(Transformation):
                         kernel_size=k,
                         pad_amount=pad,
                         input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
+                        depthwise=dw,
                     )
 
                 # do matmul
diff --git a/tests/brevitas/test_brevitas_QConv2d.py b/tests/brevitas/test_brevitas_QConv2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..198f1e7961a9e160589989b8b34b45b5fda53817
--- /dev/null
+++ b/tests/brevitas/test_brevitas_QConv2d.py
@@ -0,0 +1,76 @@
+import pytest
+import os
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantConv2d
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.quant import QuantType
+from brevitas.core.scaling import ScalingImplType
+from brevitas.core.stats import StatsOp
+
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import gen_finn_dt_tensor
+
+export_onnx_path = "test_brevitas_conv.onnx"
+
+
+@pytest.mark.parametrize("dw", [False, True])
+@pytest.mark.parametrize("in_channels", [32])
+def test_brevitas_QConv2d(dw, in_channels):
+    ishape = (1, 32, 111, 111)
+    if dw is True:
+        groups = in_channels
+        out_channels = in_channels
+        kernel_size = 3
+        padding = 1
+        stride = 1
+        w_shape = (32, 1, 3, 3)
+
+    else:
+        groups = 1
+        out_channels = 64
+        kernel_size = 1
+        padding = 0
+        stride = 1
+        w_shape = (64, 32, 1, 1)
+
+    b_conv = QuantConv2d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        groups=groups,
+        kernel_size=kernel_size,
+        padding=padding,
+        stride=stride,
+        bias=False,
+        bias_quant_type=QuantType.FP,
+        compute_output_bit_width=False,
+        compute_output_scale=False,
+        weight_bit_width=4,
+        weight_quant_type=QuantType.INT,
+        weight_scaling_impl_type=ScalingImplType.STATS,
+        weight_scaling_stats_op=StatsOp.MAX,
+        weight_scaling_per_output_channel=True,
+        weight_restrict_scaling_type=RestrictValueType.LOG_FP,
+        weight_narrow_range=True,
+        weight_scaling_min_val=2e-16,
+    )
+    weight_tensor = gen_finn_dt_tensor(DataType.INT4, w_shape)
+    b_conv.weight = torch.nn.Parameter(torch.from_numpy(weight_tensor).float())
+
+    bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32)
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    b_conv.eval()
+    expected = b_conv.forward(inp_tensor).detach().numpy()
+
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
diff --git a/tests/fpgadataflow/test_convert_to_hls_conv_layer.py b/tests/fpgadataflow/test_convert_to_hls_conv_layer.py
index d69e4c3231a3381a9eecab2a551455714dd26720..9be9c904b0be0a8c1ab2421590922ae6cf2e1295 100644
--- a/tests/fpgadataflow/test_convert_to_hls_conv_layer.py
+++ b/tests/fpgadataflow/test_convert_to_hls_conv_layer.py
@@ -32,29 +32,36 @@ from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
 @pytest.mark.parametrize(
     "conv_config", [(1, 2, 0), (1, 3, 0), (3, 2, 1), (3, 1, 0), (3, 1, 1), (5, 2, 1)]
 )
+@pytest.mark.parametrize("depthwise", [False, True])
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
 @pytest.mark.slow
 @pytest.mark.vivado
-def test_convert_to_hls_conv_layer(conv_config, exec_mode):
+def test_convert_to_hls_conv_layer(conv_config, depthwise, exec_mode):
     kernel_size, stride, pad = conv_config
     np.random.seed(0)
     idt = DataType.UINT4
 
     in_feature_dim = 7
     in_chn = 16
-    out_chn = 20
+
+    if depthwise is True:
+        group = out_chn = in_chn
+        conv_param_shape = [out_chn, 1, kernel_size, kernel_size]
+    else:
+        group = 1
+        out_chn = 20
+        conv_param_shape = [out_chn, in_chn, kernel_size, kernel_size]
 
     out_feature_dim = compute_conv_output_dim(in_feature_dim, kernel_size, stride, pad)
 
     input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
     output_shape = [1, out_chn, out_feature_dim, out_feature_dim]
 
-    conv_param_shape = [out_chn, in_chn, kernel_size, kernel_size]
     conv_weight_dt = DataType.UINT4
 
     conv_config = {}
     conv_config["dilations"] = [1, 1]
-    conv_config["group"] = 1
+    conv_config["group"] = group
     conv_config["kernel_shape"] = [kernel_size, kernel_size]
     conv_config["pads"] = [pad, pad, pad, pad]
     conv_config["strides"] = [stride, stride]
@@ -88,6 +95,18 @@ def test_convert_to_hls_conv_layer(conv_config, exec_mode):
 
     new_model = model.transform(LowerConvsToMatMul())
     new_model = new_model.transform(to_hls.InferConvInpGen())
+    if depthwise is True:
+        new_model = new_model.transform(to_hls.InferVVAU())
+    else:
+        new_model = new_model.transform(to_hls.InferQuantizedStreamingFCLayer())
+        fc_node = new_model.get_nodes_by_op_type("StreamingFCLayer_Batch")[0]
+        fc_inst = getCustomOp(fc_node)
+        mw = fc_inst.get_nodeattr("MW")
+        mh = fc_inst.get_nodeattr("MH")
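+        # pick the smallest non-trivial folding factors that divide MH and MW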
+        pe_cands = list(filter(lambda x: mh % x == 0, range(2, mh + 1)))
+        simd_cands = list(filter(lambda x: mw % x == 0, range(2, mw + 1)))
+        fc_inst.set_nodeattr("PE", pe_cands[0])
+        fc_inst.set_nodeattr("SIMD", simd_cands[0])
 
     new_model = new_model.transform(GiveUniqueNodeNames())
     new_model = new_model.transform(InferShapes())
@@ -125,3 +144,12 @@ def test_convert_to_hls_conv_layer(conv_config, exec_mode):
         padding_node = new_model.get_nodes_by_op_type("FMPadding_Batch")[0]
         padding_inst = getCustomOp(padding_node)
         assert padding_inst.get_nodeattr("SIMD") == in_chn
+
+    if depthwise is True and exec_mode == "rtlsim":
+        node = new_model.get_nodes_by_op_type("Vector_Vector_Activate_Batch")[0]
+        inst = getCustomOp(node)
+        cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
+        exp_cycles_dict = new_model.analysis(exp_cycles_per_layer)
+        exp_cycles = exp_cycles_dict[node.name]
+        assert np.isclose(exp_cycles, cycles_rtlsim, atol=11)
+        assert exp_cycles != 0
diff --git a/tests/fpgadataflow/test_depthwise_convolution.py b/tests/fpgadataflow/test_depthwise_convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..f530926e46ac5c116c3f15688c7f2face7954a30
--- /dev/null
+++ b/tests/fpgadataflow/test_depthwise_convolution.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+import onnx.helper as oh
+from onnx import TensorProto
+import numpy as np
+
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.fpgadataflow.convert_to_hls_layers import (
+    InferConvInpGen,
+    InferVVAU,
+)
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+
+import finn.core.onnx_exec as oxe
+from finn.custom_op.im2col import compute_conv_output_dim
+from finn.util.basic import calculate_signed_dot_prod_range, gen_finn_dt_tensor
+from finn.custom_op.registry import getCustomOp
+
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
+    ReplaceVerilogRelPaths,
+)
+
+
+def set_up_reference_model(act, idt, wdt, k, ifm_dim, ifm_ch, stride, padding):
+
+    # set up reference model consisting of Im2Col + MatMul (+ MultiThreshold)
+    ofm_ch = ifm_ch
+    ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad=padding)
+
+    if act is None:
+        odt = DataType.INT32
+    else:
+        odt = act
+        out_act = oh.make_tensor_value_info(
+            "out_act", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, ofm_ch]
+        )
+        T = oh.make_tensor_value_info("T", TensorProto.FLOAT, [ofm_ch, 15])
+        tdt = DataType.INT32
+        thresh_node = oh.make_node(
+            "MultiThreshold",
+            domain="finn",
+            inputs=["outp", "T"],
+            outputs=["out_act"],
+            data_layout="NHWC",
+            out_dtype=odt.name,
+            out_scale=1.0,
+            out_bias=0.0,
+        )
+
+    # set up onnx model
+    inp = oh.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
+    )
+    outp = oh.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, ofm_ch]
+    )
+
+    W_sparse = oh.make_tensor_value_info(
+        "W_sparse", TensorProto.FLOAT, [ifm_ch * k * k, ofm_ch]
+    )
+
+    im2col_node = oh.make_node(
+        "Im2Col",
+        domain="finn",
+        inputs=["inp"],
+        outputs=["im2col_out"],
+        kernel_size=k,
+        stride=stride,
+        pad_amount=padding,
+        input_shape="(1, {}, {}, {})".format(ifm_dim, ifm_dim, ifm_ch),
+        depthwise=1,
+    )
+
+    matmul_node = oh.make_node(
+        "MatMul", inputs=["im2col_out", "W_sparse"], outputs=["outp"]
+    )
+
+    if act is None:
+        node_list = [im2col_node, matmul_node]
+        global_out = outp
+        value_info = [W_sparse]
+    else:
+        node_list = [im2col_node, matmul_node, thresh_node]
+        global_out = out_act
+        value_info = [W_sparse, T]
+
+    graph = oh.make_graph(
+        nodes=node_list,
+        name="lowered_dw_cnv_graph",
+        inputs=[inp],
+        outputs=[global_out],
+        value_info=value_info,
+    )
+    model = oh.make_model(graph, producer_name="lowered_dw_cnv-model")
+    model = ModelWrapper(model)
+
+    # initialize model
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype(model.graph.output[0].name, odt)
+    model.set_tensor_datatype("W_sparse", wdt)
+
+    w_tensor = gen_finn_dt_tensor(wdt, [ofm_ch, 1, k, k])
+    # create sparse matrix
+    W_matrix = np.zeros((ofm_ch, ifm_ch, k, k))
+    for ch in range(ifm_ch):
+        W_matrix[ch][ch] = w_tensor[ch][0]
+    W_matrix = W_matrix.astype(np.float32)
+    W_matrix = W_matrix.transpose(0, 2, 3, 1)
+    W_matrix = W_matrix.reshape(ofm_ch, ifm_ch * k * k)
+
+    model.set_initializer("W_sparse", W_matrix.T)
+    sparsity = {"dw": {"kernel_shape": k}}
+    model.set_tensor_sparsity("W_sparse", sparsity)
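+    # InferVVAU keys on this sparsity annotation to recognize the MatMul
+    # as a lowered depthwise convolution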
+
+    if act is not None:
+        (min, max) = calculate_signed_dot_prod_range(idt, wdt, ifm_ch * k * k)
+        n_steps = odt.get_num_possible_values() - 1
+        T_values = np.random.randint(min, max - 1, (ofm_ch, n_steps)).astype(np.float32)
+        # provide non-decreasing thresholds
+        T_values = np.sort(T_values, axis=1)
+        model.set_initializer("T", T_values)
+        model.set_tensor_datatype("T", tdt)
+
+    model = model.transform(InferShapes())
+
+    return model
+
+
+# PE
+@pytest.mark.parametrize("pe", [1, 2, 4])
+# Output activation
+@pytest.mark.parametrize("act", [None, DataType.UINT4])
+# kernel size
+@pytest.mark.parametrize("k", [2, 4])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# padding
+@pytest.mark.parametrize("padding", [0, 1])
+@pytest.mark.slow
+@pytest.mark.vivado
+def test_depthwise_conv_hls_cppsim(act, pe, k, stride, padding):
+    idt = wdt = DataType.INT4
+    ifm_dim = 6
+    ifm_ch = 4
+
+    # set up reference model consisting of Im2Col + MatMul (+ MultiThreshold)
+    model = set_up_reference_model(act, idt, wdt, k, ifm_dim, ifm_ch, stride, padding)
+
+    input_tensor = gen_finn_dt_tensor(idt, [1, ifm_dim, ifm_dim, ifm_ch])
+    input_dict = {"inp": input_tensor}
+
+    new_model = model.transform(InferConvInpGen())
+    new_model = new_model.transform(InferVVAU())
+
+    # set SIMD in the ConvInputGen node and PE in the VVAU node
+    for n in new_model.graph.node:
+        if n.op_type == "ConvolutionInputGenerator":
+            convinputgen_node = getCustomOp(n)
+            convinputgen_node.set_nodeattr("SIMD", pe)
+        elif n.op_type == "Vector_Vector_Activate_Batch":
+            vvau_node = getCustomOp(n)
+            vvau_node.set_nodeattr("PE", pe)
+    new_model = new_model.transform(SetExecMode("cppsim"))
+    new_model = new_model.transform(PrepareCppSim())
+    new_model = new_model.transform(CompileCppSim())
+
+    assert oxe.compare_execution(model, new_model, input_dict)
+
+
+# PE
+@pytest.mark.parametrize("pe", [1, 2, 4])
+# Output activation
+@pytest.mark.parametrize("act", [None, DataType.UINT4])
+# kernel size
+@pytest.mark.parametrize("k", [2, 4])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# padding
+@pytest.mark.parametrize("padding", [0, 1])
+@pytest.mark.slow
+@pytest.mark.vivado
+def test_depthwise_conv_hls_rtlsim(act, pe, k, stride, padding):
+    idt = wdt = DataType.INT4
+    ifm_dim = 6
+    ifm_ch = 4
+
+    # set up reference model consisting of Im2Col + MatMul (+ MultiThreshold)
+    model = set_up_reference_model(act, idt, wdt, k, ifm_dim, ifm_ch, stride, padding)
+
+    input_tensor = gen_finn_dt_tensor(idt, [1, ifm_dim, ifm_dim, ifm_ch])
+    input_dict = {"inp": input_tensor}
+
+    new_model = model.transform(InferConvInpGen())
+    new_model = new_model.transform(InferVVAU())
+
+    # set SIMD in the ConvInputGen node and PE in the VVAU node
+    for n in new_model.graph.node:
+        if n.op_type == "ConvolutionInputGenerator":
+            convinputgen_node = getCustomOp(n)
+            convinputgen_node.set_nodeattr("SIMD", pe)
+        elif n.op_type == "Vector_Vector_Activate_Batch":
+            vvau_node = getCustomOp(n)
+            vvau_node.set_nodeattr("PE", pe)
+
+    new_model = new_model.transform(SetExecMode("rtlsim"))
+    new_model = new_model.transform(GiveUniqueNodeNames())
+    new_model = new_model.transform(PrepareIP("xc7z020clg400-1", 5))
+    new_model = new_model.transform(HLSSynthIP())
+    new_model = new_model.transform(ReplaceVerilogRelPaths())
+    new_model = new_model.transform(PrepareRTLSim())
+
+    assert oxe.compare_execution(model, new_model, input_dict)
diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py
index 16c574b29b55e314b06661b28e4bb869bd6b7996..ab545d483321f8c52625b5401828277987bba3a9 100644
--- a/tests/transformation/test_conv_lowering.py
+++ b/tests/transformation/test_conv_lowering.py
@@ -26,6 +26,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import pytest
 import onnx.helper as oh
 from onnx import TensorProto
 import os
@@ -34,12 +35,16 @@ import brevitas.onnx as bo
 import numpy as np
 
 from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.test import get_test_model_trained
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
 import finn.core.onnx_exec as oxe
+from finn.custom_op.im2col import compute_conv_output_dim
+from finn.util.basic import gen_finn_dt_tensor
+from finn.custom_op.registry import getCustomOp
 
 export_onnx_path = "test_conv_lowering.onnx"
 
@@ -68,6 +73,76 @@ def test_conv_lowering_cnv_w1a1():
     os.remove(export_onnx_path)
 
 
+# input datatype
+@pytest.mark.parametrize("idt", [DataType.INT2, DataType.INT4])
+# kernel size
+@pytest.mark.parametrize("k", [2, 4])
+# input dimension
+@pytest.mark.parametrize("ifm_dim", [4, 6])
+# input channels
+@pytest.mark.parametrize("ifm_ch", [2, 3])
+# stride
+@pytest.mark.parametrize("stride", [1, 2])
+# padding
+@pytest.mark.parametrize("padding", [[0, 0, 0, 0], [1, 1, 1, 1]])
+def test_depthwise_conv_lowering(idt, k, ifm_dim, ifm_ch, stride, padding):
+    wdt = idt
+    odt = DataType.INT32
+    ofm_ch = ifm_ch
+    ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad=padding[0])
+
+    # set up onnx model
+    inp = oh.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+    )
+    outp = oh.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ofm_ch, ofm_dim, ofm_dim]
+    )
+
+    W = oh.make_tensor_value_info("W", TensorProto.FLOAT, [ofm_ch, 1, k, k])
+
+    dw_cnv = oh.make_node(
+        "Conv",
+        inputs=["inp", "W"],
+        outputs=["outp"],
+        kernel_shape=[k, k],
+        pads=padding,
+        strides=[stride, stride],
+        group=ifm_ch,
+    )
+    graph = oh.make_graph(
+        nodes=[dw_cnv],
+        name="dw_cnv_graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[W],
+    )
+
+    model = oh.make_model(graph, producer_name="dws_cnv-model")
+    model = ModelWrapper(model)
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+    model.set_tensor_datatype("W", wdt)
+    w_tensor = gen_finn_dt_tensor(wdt, [ofm_ch, 1, k, k])
+    model.set_initializer("W", w_tensor)
+    model = model.transform(InferShapes())
+
+    input_tensor = gen_finn_dt_tensor(idt, [1, ifm_ch, ifm_dim, ifm_dim])
+    input_dict = {"inp": input_tensor}
+    output_dict = oxe.execute_onnx(model, input_dict)
+    expected = output_dict["outp"]
+
+    model = model.transform(LowerConvsToMatMul())
+    output_dict = oxe.execute_onnx(model, input_dict)
+    produced = output_dict["outp"]
+    assert (produced == expected).all()
+
+    # check if created nodes have attributes that indicate depthwise conv
+    assert model.get_tensor_sparsity("W") is not None
+    im2col_node = getCustomOp(model.graph.node[1])
+    assert im2col_node.get_nodeattr("depthwise") == 1
+
+
 def test_conv_lowering_conv_1x1():
     np.random.seed(0)