Skip to content
Snippets Groups Projects
Commit a0dea606 authored by auphelia's avatar auphelia
Browse files

Merge changes in origin/feature/onnx_export_quant_avg_pool into local feature branch

parents c39e085a bd887eba
No related branches found
No related tags found
No related merge requests found
......@@ -72,7 +72,7 @@ USER $UNAME
# cloning dependency repos (as user)
# Brevitas
RUN git clone https://github.com/auphelia/brevitas.git /workspace/brevitas
RUN git clone https://github.com/Xilinx/brevitas.git /workspace/brevitas
# CNPY
RUN git clone https://github.com/rogersce/cnpy.git /workspace/cnpy
# FINN hlslib
......
......@@ -13,8 +13,7 @@ gecho () {
# checkout the correct dependency repo commits
# the repos themselves are cloned in the Dockerfile
#BREVITAS_COMMIT=989cdfdba4700fdd900ba0b25a820591d561c21a
BREVITAS_COMMIT=265f61355d68054f11106b6f5903ab737b91038f
BREVITAS_COMMIT=d45ac15325c7f33de6a9d2d2f654ef48cb20c49d
CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
......
import onnx # noqa
import torch
import numpy as np
import brevitas.onnx as bo
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import pack_quant_tensor
from brevitas.core.quant import QuantType
import pytest
export_onnx_path = "test_avg_pool.onnx"
......@@ -12,7 +16,10 @@ export_onnx_path = "test_avg_pool.onnx"
@pytest.mark.parametrize("signed", [False])
@pytest.mark.parametrize("bit_width", [4])
def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
ishape = (1, 1024, 7, 7)
ch = 4
ishape = (1, ch, 7, 7)
input_bit_width = 32
ibw_tensor = torch.Tensor([input_bit_width])
b_avgpool = QuantAvgPool2d(
kernel_size=kernel_size,
......@@ -22,4 +29,10 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
max_overall_bit_width=bit_width,
quant_type=QuantType.INT,
)
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path)
# call forward pass manually once to cache scale factor and bitwidth
input_tensor = torch.from_numpy(np.zeros(ishape)).float()
output_scale = torch.from_numpy(np.ones((1, ch, 1, 1))).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
)
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment