diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 65513f3148e0fed2583d02e1eba249bc9a1f2f6e..0074cce02f7de57dc778e0b671c484233df72a8a 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -13,7 +13,7 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=7696326e5f279cacffd5b6ac8d9e8d81deec3978
+BREVITAS_COMMIT=026a509186b7e7b0b65d46a2f905043d41069306
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=13e9b0772a27a3a1efc40c878d8e78ed09efb716
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc328a9f4f6670041d33491d58af6c553bafac9
--- /dev/null
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -0,0 +1,83 @@
+import numpy as np
+from onnx import TensorProto, helper
+import onnxruntime as rt
+
+from finn.custom_op import CustomOp
+from finn.core.datatype import DataType
+
+
+class QuantAvgPool2d(CustomOp):
+    """Class that corresponds to the quantized average pooling
+    layer from brevitas"""
+
+    def get_nodeattr_types(self):
+        return {
+            "stride": ("i", True, 1),
+            "kernel": ("i", True, 1),
+            "ibits": ("i", True, 1),
+            "obits": ("i", True, 1),
+            "signed": ("i", True, 0),
+        }
+
+    def make_shape_compatible_op(self, model):
+        node = self.onnx_node
+        k = self.get_nodeattr("kernel")
+        s = self.get_nodeattr("stride")
+        return helper.make_node(
+            "AveragePool",
+            inputs=[node.input[0]],
+            outputs=[node.output[0]],
+            kernel_shape=[k, k],
+            strides=[s, s],
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        bw = self.get_nodeattr("obits")
+        if bw in [2, 4, 8, 16, 32]:
+            if self.get_nodeattr("signed") == 0:
+                dtype = DataType["UINT%d" % bw]
+            else:
+                dtype = DataType["INT%d" % bw]
+        else:
+            raise Exception("Unsupported output datatype for QuantAvgPool2d")
+        model.set_tensor_datatype(node.output[0], dtype)
+
+    def execute_node(self, context, graph):
+        # create a standard average pooling node to help calculate the result
+        node = self.onnx_node
+        k = self.get_nodeattr("kernel")
+        s = self.get_nodeattr("stride")
+        ishape = context[node.input[0]].shape
+        oshape = context[node.output[0]].shape
+        inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
+        outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
+        node_avgpool = helper.make_node(
+            "AveragePool",
+            inputs=[node.input[0]],
+            outputs=[node.output[0]],
+            kernel_shape=[k, k],
+            strides=[s, s],
+        )
+        graph_avgpool = helper.make_graph(
+            nodes=[node_avgpool],
+            name="single-avgpool-exec",
+            inputs=[inp],
+            outputs=[outp],
+        )
+        model_avgpool = helper.make_model(graph_avgpool)
+        idict = {node.input[0]: context[node.input[0]]}
+        sess = rt.InferenceSession(model_avgpool.SerializeToString())
+        result_temp = sess.run(None, idict)
+        # remove scaling introduced by average
+        result_temp = result_temp[0] * (k * k)
+        ibits = self.get_nodeattr("ibits")
+        max_value = 2 ** ibits - 1
+        max_value = max_value * k * k
+        max_bit_width = int(max_value).bit_length()
+        shift_bits = max_bit_width - self.get_nodeattr("obits")
+        result = np.right_shift(result_temp.astype(int), shift_bits)
+        context[node.output[0]] = result.astype(np.float32)
+
+    def verify_node(self):
+        pass
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index 614a3d7ffd70d0b102bad2b76177a2d3b32765c7..2dae826cf9712bef17d0053a0878c41ef51fec36 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -48,6 +48,7 @@ from finn.custom_op.fpgadataflow.fmpadding import FMPadding_Batch
 from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
+from finn.custom_op.quantavgpool2d import QuantAvgPool2d
 from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
 
 # create a mapping of all known CustomOp names and classes
@@ -69,6 +70,7 @@ custom_op["FMPadding_Batch"] = FMPadding_Batch
 custom_op["Thresholding_Batch"] = Thresholding_Batch
 custom_op["AddStreams_Batch"] = AddStreams_Batch
 custom_op["LabelSelect_Batch"] = LabelSelect_Batch
+custom_op["QuantAvgPool2d"] = QuantAvgPool2d
 custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
 
 
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 1acd4e3abe2d77248810cf15c15475e806a3bd32..39b7a787be8c725e7b6d474757dd96fc4848dfe0 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -71,7 +71,13 @@ def _infer_node_datatype(model, node):
         else:
             # unknown, assume node produces float32 outputs
             for o in node.output:
-                model.set_tensor_datatype(o, DataType.FLOAT32)
+                # check if output datatype is already set to a value != FLOAT32
+                odtype = model.get_tensor_datatype(o)
+                if odtype is not None and odtype != DataType.FLOAT32:
+                    # don't change data type
+                    model.set_tensor_datatype(o, odtype)
+                else:
+                    model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes
diff --git a/src/finn/transformation/streamline/collapse_repeated.py b/src/finn/transformation/streamline/collapse_repeated.py
index 67824ad4f633983b93e3178d03118927a1ddd85b..769bed841ce07c1c9c62f762de4b2c0937a6d68f 100644
--- a/src/finn/transformation/streamline/collapse_repeated.py
+++ b/src/finn/transformation/streamline/collapse_repeated.py
@@ -30,6 +30,7 @@ from onnx import helper as oh
 
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.core.datatype import DataType
 
 
 class CollapseRepeatedOp(Transformation):
@@ -83,6 +84,9 @@ class CollapseRepeatedOp(Transformation):
                     graph.node.insert(node_ind, new_node)
                     # replace parameter value
                     model.set_initializer(new_node_param_name, new_param)
+                    # be conservative with param/output DataTypes
+                    model.set_tensor_datatype(new_node_param_name, DataType.FLOAT32)
+                    model.set_tensor_datatype(end_name, DataType.FLOAT32)
                     # remove old nodes
                     graph.node.remove(n)
                     graph.node.remove(consumer)
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index 585d6b90a59ca9f3dac56a6133de705c2f56f025..809b34157ee4b7890a4155bd09f33dcc85c6ceec 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -106,6 +106,14 @@ def get_finn_root():
         )
 
 
+def get_execution_error_thresh():
+    "Return the max error that is allowed for rounding in FINN execution."
+    try:
+        return float(os.environ["ERROR_THRESH"])
+    except KeyError:
+        return 1e-2
+
+
 def make_build_dir(prefix=""):
     """Creates a temporary folder with given prefix to be used as a build dir.
     Use this function instead of tempfile.mkdtemp to ensure any generated files
@@ -305,7 +313,7 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F
             )
         # check if rounded values are not too far from original values
         max_error = max(np.abs(current_values - updated_values).flatten())
-        if max_error <= 1e-4:
+        if max_error <= get_execution_error_thresh():
             if check_values is True:
                 # check again if values can now be represented with set finn datatype
                 # TODO: vectorize with numpy
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..24854a2153df9af78feb8352ca119e831a9ac9eb
--- /dev/null
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -0,0 +1,103 @@
+import os
+
+import onnx  # noqa
+import torch
+import numpy as np
+import brevitas.onnx as bo
+from brevitas.nn import QuantAvgPool2d
+from brevitas.quant_tensor import pack_quant_tensor
+from brevitas.core.quant import QuantType
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.util.basic import gen_finn_dt_tensor
+import finn.core.onnx_exec as oxe
+
+import pytest
+
+export_onnx_path = "test_avg_pool.onnx"
+
+
+@pytest.mark.parametrize("kernel_size", [2, 3])
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("signed", [False, True])
+@pytest.mark.parametrize("bit_width", [2, 4])
+@pytest.mark.parametrize("input_bit_width", [4, 8, 32])
+@pytest.mark.parametrize("channels", [2, 4])
+@pytest.mark.parametrize("idim", [7, 8])
+def test_brevitas_avg_pool_export(
+    kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
+):
+    ishape = (1, channels, idim, idim)
+    ibw_tensor = torch.Tensor([input_bit_width])
+
+    b_avgpool = QuantAvgPool2d(
+        kernel_size=kernel_size,
+        stride=stride,
+        signed=signed,
+        min_overall_bit_width=bit_width,
+        max_overall_bit_width=bit_width,
+        quant_type=QuantType.INT,
+    )
+    # call forward pass manually once to cache scale factor and bitwidth
+    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
+    scale = np.ones((1, channels, 1, 1))
+    output_scale = torch.from_numpy(scale).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+    )
+    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
+    model = ModelWrapper(export_onnx_path)
+
+    # determine input FINN datatype
+    if signed is True:
+        prefix = "INT"
+    else:
+        prefix = "UINT"
+    dt_name = prefix + str(input_bit_width // 2)
+    dtype = DataType[dt_name]
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    # execution with input tensor using integers and scale = 1
+    # calculate golden output
+    inp = gen_finn_dt_tensor(dtype, ishape)
+    input_tensor = torch.from_numpy(inp).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+    )
+    b_avgpool.eval()
+    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
+
+    # finn execution
+    idict = {model.graph.input[0].name: inp}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    assert (expected == produced).all()
+
+    # execution with input tensor using float and scale != 1
+    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
+        np.float32
+    )
+    inp_tensor = inp * scale
+    input_tensor = torch.from_numpy(inp_tensor).float()
+    input_scale = torch.from_numpy(scale).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor
+    )
+    # export again to set the scale values correctly
+    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    b_avgpool.eval()
+    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
+    # finn execution
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+
+    assert np.isclose(expected, produced).all()
+
+    os.remove(export_onnx_path)