Skip to content
Snippets Groups Projects
Commit 52f26835 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] fix avgpool export test

parent 1a850c9d
No related branches found
No related tags found
No related merge requests found
......@@ -25,31 +25,29 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import onnx # noqa
import torch
import numpy as np
import brevitas.onnx as bo
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor
from brevitas.core.quant import QuantType
import pytest
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe
import pytest
from brevitas.export import FINNManager
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor
export_onnx_path = "test_brevitas_avg_pool_export.onnx"
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [False, True])
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
......@@ -57,90 +55,46 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
def test_brevitas_avg_pool_export(
kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
):
ishape = (1, channels, idim, idim)
ibw_tensor = torch.Tensor([input_bit_width])
b_avgpool = QuantAvgPool2d(
kernel_size=kernel_size,
stride=stride,
bit_width=bit_width,
quant_type=QuantType.INT,
)
# call forward pass manually once to cache scale factor and bitwidth
input_tensor = torch.from_numpy(np.zeros(ishape)).float()
scale = np.ones((1, channels, 1, 1))
zpt = torch.from_numpy(np.zeros((1))).float()
output_scale = torch.from_numpy(scale).float()
input_quant_tensor = QuantTensor(
value=input_tensor,
scale=output_scale,
bit_width=ibw_tensor,
signed=signed,
zero_point=zpt,
)
bo.export_finn_onnx(
b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
quant_avgpool = QuantAvgPool2d(
kernel_size=kernel_size, stride=stride, bit_width=bit_width
)
model = ModelWrapper(export_onnx_path)
quant_avgpool.eval()
# determine input FINN datatype
if signed is True:
prefix = "INT"
else:
prefix = "UINT"
# determine input
prefix = "INT" if signed else "UINT"
dt_name = prefix + str(input_bit_width)
dtype = DataType[dt_name]
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
# execution with input tensor using integers and scale = 1
# calculate golden output
inp = gen_finn_dt_tensor(dtype, ishape)
input_tensor = torch.from_numpy(inp).float()
input_quant_tensor = QuantTensor(
value=input_tensor,
scale=output_scale,
bit_width=ibw_tensor,
signed=signed,
zero_point=zpt,
)
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
# finn execution
idict = {model.graph.input[0].name: inp}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
assert (expected == produced).all()
# execution with input tensor using float and scale != 1
scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
input_shape = (1, channels, idim, idim)
input_array = gen_finn_dt_tensor(dtype, input_shape)
# Brevitas QuantAvgPool layers need QuantTensors to export correctly
# which requires setting up a QuantTensor instance with the scale
# factor, zero point, bitwidth and signedness
scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
np.float32
)
inp_tensor = inp * scale
input_tensor = torch.from_numpy(inp_tensor).float()
input_scale = torch.from_numpy(scale).float()
input_tensor = torch.from_numpy(input_array * scale_array).float()
scale_tensor = torch.from_numpy(scale_array).float()
zp = torch.tensor(0.0)
input_quant_tensor = QuantTensor(
value=input_tensor,
scale=input_scale,
bit_width=ibw_tensor,
signed=signed,
zero_point=zpt,
input_tensor, scale_tensor, zp, input_bit_width, signed
)
# export again to set the scale values correctly
bo.export_finn_onnx(
b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
# export
FINNManager.export(
quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
# finn execution
idict = {model.graph.input[0].name: inp_tensor}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
assert np.isclose(expected, produced).all()
# reference brevitas output
ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
# finn output
idict = {model.graph.input[0].name: input_array}
odict = oxe.execute_onnx(model, idict, True)
finn_output = odict[model.graph.output[0].name]
# compare outputs
assert np.isclose(ref_output_array, finn_output).all()
# cleanup
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment