diff --git a/docker/Dockerfile.finn_dev b/docker/Dockerfile.finn_dev index 0645a36113b8b8115bc711eb8fa7d39bc606c0e4..1c2cb19d14137b866b55417522fdebb8e0d7ad90 100644 --- a/docker/Dockerfile.finn_dev +++ b/docker/Dockerfile.finn_dev @@ -72,7 +72,7 @@ USER $UNAME # cloning dependency repos (as user) # Brevitas -RUN git clone https://github.com/auphelia/brevitas.git /workspace/brevitas +RUN git clone https://github.com/Xilinx/brevitas.git /workspace/brevitas # CNPY RUN git clone https://github.com/rogersce/cnpy.git /workspace/cnpy # FINN hlslib diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh index 45a6ab0d25896f096339b01e0a9abad2b6154992..4baa12aa5215ffc7351b08a6b2e9868402ec749d 100644 --- a/docker/finn_entrypoint.sh +++ b/docker/finn_entrypoint.sh @@ -13,8 +13,7 @@ gecho () { # checkout the correct dependency repo commits # the repos themselves are cloned in the Dockerfile -#BREVITAS_COMMIT=989cdfdba4700fdd900ba0b25a820591d561c21a -BREVITAS_COMMIT=265f61355d68054f11106b6f5903ab737b91038f +BREVITAS_COMMIT=d45ac15325c7f33de6a9d2d2f654ef48cb20c49d CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4 HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6 diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index d4a5776ed0eb5bb18bab26ad8bc9b90d578b0f96..3bf6a8ed6f86ba4e54f49b64844a7a51337be445 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -1,7 +1,11 @@ import onnx # noqa +import torch +import numpy as np import brevitas.onnx as bo from brevitas.nn import QuantAvgPool2d +from brevitas.quant_tensor import pack_quant_tensor from brevitas.core.quant import QuantType + import pytest export_onnx_path = "test_avg_pool.onnx" @@ -12,7 +16,10 @@ export_onnx_path = "test_avg_pool.onnx" @pytest.mark.parametrize("signed", [False]) @pytest.mark.parametrize("bit_width", [4]) def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): - ishape = (1, 1024, 7, 7) + ch = 4 + ishape = (1, ch, 7, 7) + input_bit_width = 32 + ibw_tensor = torch.Tensor([input_bit_width]) b_avgpool = QuantAvgPool2d( kernel_size=kernel_size, @@ -22,4 +29,10 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): max_overall_bit_width=bit_width, quant_type=QuantType.INT, ) - bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path) + # call forward pass manually once to cache scale factor and bitwidth + input_tensor = torch.from_numpy(np.zeros(ishape)).float() + output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float() + input_quant_tensor = pack_quant_tensor( + tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor + ) + bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)