diff --git a/docker/Dockerfile.finn_dev b/docker/Dockerfile.finn_dev
index 0645a36113b8b8115bc711eb8fa7d39bc606c0e4..1c2cb19d14137b866b55417522fdebb8e0d7ad90 100644
--- a/docker/Dockerfile.finn_dev
+++ b/docker/Dockerfile.finn_dev
@@ -72,7 +72,7 @@ USER $UNAME
 
 # cloning dependency repos (as user)
 # Brevitas
-RUN git clone https://github.com/auphelia/brevitas.git /workspace/brevitas
+RUN git clone https://github.com/Xilinx/brevitas.git /workspace/brevitas
 # CNPY
 RUN git clone https://github.com/rogersce/cnpy.git /workspace/cnpy
 # FINN hlslib
diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 45a6ab0d25896f096339b01e0a9abad2b6154992..4baa12aa5215ffc7351b08a6b2e9868402ec749d 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -13,8 +13,7 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-#BREVITAS_COMMIT=989cdfdba4700fdd900ba0b25a820591d561c21a
-BREVITAS_COMMIT=265f61355d68054f11106b6f5903ab737b91038f
+BREVITAS_COMMIT=d45ac15325c7f33de6a9d2d2f654ef48cb20c49d
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
 PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index d4a5776ed0eb5bb18bab26ad8bc9b90d578b0f96..3bf6a8ed6f86ba4e54f49b64844a7a51337be445 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -1,7 +1,11 @@
 import onnx  # noqa
+import torch
+import numpy as np
 import brevitas.onnx as bo
 from brevitas.nn import QuantAvgPool2d
+from brevitas.quant_tensor import pack_quant_tensor
 from brevitas.core.quant import QuantType
+
 import pytest
 
 export_onnx_path = "test_avg_pool.onnx"
@@ -12,7 +16,10 @@ export_onnx_path = "test_avg_pool.onnx"
 @pytest.mark.parametrize("signed", [False])
 @pytest.mark.parametrize("bit_width", [4])
 def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
-    ishape = (1, 1024, 7, 7)
+    ch = 4
+    ishape = (1, ch, 7, 7)
+    input_bit_width = 32
+    ibw_tensor = torch.Tensor([input_bit_width])
 
     b_avgpool = QuantAvgPool2d(
         kernel_size=kernel_size,
@@ -22,4 +29,10 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
         max_overall_bit_width=bit_width,
         quant_type=QuantType.INT,
     )
-    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path)
+    # call forward pass manually once to cache scale factor and bitwidth
+    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
+    output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+    )
+    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)