Commit 52f26835 authored by Yaman Umuroglu

[Test] fix avgpool export test

parent 1a850c9d
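To check the fix locally, the updated test can be run on its own with pytest; a minimal sketch (hypothetical invocation, assuming a FINN dev environment with this test file in the working directory):

import pytest

# Run only this test, across all of its parametrized cases, with verbose output.
pytest.main(["-v", "-k", "test_brevitas_avg_pool_export"])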
@@ -25,31 +25,29 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import os
-import onnx  # noqa
 import torch
 import numpy as np
-import brevitas.onnx as bo
-from brevitas.nn import QuantAvgPool2d
-from brevitas.quant_tensor import QuantTensor
-from brevitas.core.quant import QuantType
+import pytest
+
+import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.core.datatype import DataType
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.util.basic import gen_finn_dt_tensor
-import finn.core.onnx_exec as oxe
-import pytest
+
+from brevitas.export import FINNManager
+from brevitas.nn import QuantAvgPool2d
+from brevitas.quant_tensor import QuantTensor
 
 export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 
 
 @pytest.mark.parametrize("kernel_size", [2, 3])
 @pytest.mark.parametrize("stride", [1, 2])
-@pytest.mark.parametrize("signed", [False, True])
+@pytest.mark.parametrize("signed", [True, False])
 @pytest.mark.parametrize("bit_width", [2, 4])
 @pytest.mark.parametrize("input_bit_width", [4, 8, 16])
 @pytest.mark.parametrize("channels", [2, 4])
@@ -57,90 +55,46 @@ export_onnx_path = "test_brevitas_avg_pool_export.onnx"
 def test_brevitas_avg_pool_export(
     kernel_size, stride, signed, bit_width, input_bit_width, channels, idim
 ):
-    ishape = (1, channels, idim, idim)
-    ibw_tensor = torch.Tensor([input_bit_width])
-
-    b_avgpool = QuantAvgPool2d(
-        kernel_size=kernel_size,
-        stride=stride,
-        bit_width=bit_width,
-        quant_type=QuantType.INT,
-    )
-    # call forward pass manually once to cache scale factor and bitwidth
-    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
-    scale = np.ones((1, channels, 1, 1))
-    zpt = torch.from_numpy(np.zeros((1))).float()
-    output_scale = torch.from_numpy(scale).float()
-    input_quant_tensor = QuantTensor(
-        value=input_tensor,
-        scale=output_scale,
-        bit_width=ibw_tensor,
-        signed=signed,
-        zero_point=zpt,
-    )
-    bo.export_finn_onnx(
-        b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
-    )
-    model = ModelWrapper(export_onnx_path)
-
-    # determine input FINN datatype
-    if signed is True:
-        prefix = "INT"
-    else:
-        prefix = "UINT"
-    dt_name = prefix + str(input_bit_width)
-    dtype = DataType[dt_name]
-    model = model.transform(InferShapes())
-    model = model.transform(InferDataTypes())
-
-    # execution with input tensor using integers and scale = 1
-    # calculate golden output
-    inp = gen_finn_dt_tensor(dtype, ishape)
-    input_tensor = torch.from_numpy(inp).float()
-    input_quant_tensor = QuantTensor(
-        value=input_tensor,
-        scale=output_scale,
-        bit_width=ibw_tensor,
-        signed=signed,
-        zero_point=zpt,
-    )
-    b_avgpool.eval()
-    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
-    # finn execution
-    idict = {model.graph.input[0].name: inp}
-    odict = oxe.execute_onnx(model, idict, True)
-    produced = odict[model.graph.output[0].name]
-    assert (expected == produced).all()
-
-    # execution with input tensor using float and scale != 1
-    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
-        np.float32
-    )
-    inp_tensor = inp * scale
-    input_tensor = torch.from_numpy(inp_tensor).float()
-    input_scale = torch.from_numpy(scale).float()
-    input_quant_tensor = QuantTensor(
-        value=input_tensor,
-        scale=input_scale,
-        bit_width=ibw_tensor,
-        signed=signed,
-        zero_point=zpt,
-    )
-    # export again to set the scale values correctly
-    bo.export_finn_onnx(
-        b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
-    )
-    model = ModelWrapper(export_onnx_path)
-    model = model.transform(InferShapes())
-    model = model.transform(InferDataTypes())
-    b_avgpool.eval()
-    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
-    # finn execution
-    idict = {model.graph.input[0].name: inp_tensor}
-    odict = oxe.execute_onnx(model, idict, True)
-    produced = odict[model.graph.output[0].name]
-    assert np.isclose(expected, produced).all()
+    quant_avgpool = QuantAvgPool2d(
+        kernel_size=kernel_size, stride=stride, bit_width=bit_width
+    )
+    quant_avgpool.eval()
+
+    # determine input
+    prefix = "INT" if signed else "UINT"
+    dt_name = prefix + str(input_bit_width)
+    dtype = DataType[dt_name]
+    input_shape = (1, channels, idim, idim)
+    input_array = gen_finn_dt_tensor(dtype, input_shape)
+    # Brevitas QuantAvgPool layers need QuantTensors to export correctly
+    # which requires setting up a QuantTensor instance with the scale
+    # factor, zero point, bitwidth and signedness
+    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(
+        np.float32
+    )
+    input_tensor = torch.from_numpy(input_array * scale_array).float()
+    scale_tensor = torch.from_numpy(scale_array).float()
+    zp = torch.tensor(0.0)
+    input_quant_tensor = QuantTensor(
+        input_tensor, scale_tensor, zp, input_bit_width, signed
+    )
+
+    # export
+    FINNManager.export(
+        quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
+    )
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    # reference brevitas output
+    ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
+    # finn output
+    idict = {model.graph.input[0].name: input_array}
+    odict = oxe.execute_onnx(model, idict, True)
+    finn_output = odict[model.graph.output[0].name]
+    # compare outputs
+    assert np.isclose(ref_output_array, finn_output).all()
+    # cleanup
     os.remove(export_onnx_path)
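For readers skimming the diff: the test now does a single QuantTensor-annotated export through Brevitas' FINNManager instead of the old two-pass brevitas.onnx export. A condensed sketch of that flow, built only from calls that appear in the diff (the kernel size, shapes, bit width and output path below are illustrative, not from the commit):

import torch
from brevitas.export import FINNManager
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor

pool = QuantAvgPool2d(kernel_size=2, stride=2, bit_width=4)
pool.eval()

# QuantTensor bundles the value with its scale, zero point, bit width and
# signedness, which the exporter needs to annotate the model input.
value = torch.randint(0, 255, (1, 2, 8, 8)).float() * 0.5  # integers times scale
scale = torch.full((1, 2, 1, 1), 0.5)
inp = QuantTensor(value, scale, torch.tensor(0.0), 8, False)

FINNManager.export(pool, export_path="avgpool_export.onnx", input_t=inp)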