Skip to content
Snippets Groups Projects
Commit 97a77227 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] update avgpool export test for newest Brevitas

parent ba4bbe59
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,7 @@ import torch
import numpy as np
import brevitas.onnx as bo
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import pack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.core.quant import QuantType
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
......@@ -41,11 +41,18 @@ def test_brevitas_avg_pool_export(
# call forward pass manually once to cache scale factor and bitwidth
input_tensor = torch.from_numpy(np.zeros(ishape)).float()
scale = np.ones((1, channels, 1, 1))
zpt = torch.from_numpy(np.zeros((1))).float()
output_scale = torch.from_numpy(scale).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed
input_quant_tensor = QuantTensor(
value=input_tensor,
scale=output_scale,
bit_width=ibw_tensor,
signed=signed,
zero_point=zpt,
)
bo.export_finn_onnx(
b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
)
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
model = ModelWrapper(export_onnx_path)
# determine input FINN datatype
......@@ -62,8 +69,12 @@ def test_brevitas_avg_pool_export(
# calculate golden output
inp = gen_finn_dt_tensor(dtype, ishape)
input_tensor = torch.from_numpy(inp).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor, signed=signed
input_quant_tensor = QuantTensor(
value=input_tensor,
scale=output_scale,
bit_width=ibw_tensor,
signed=signed,
zero_point=zpt,
)
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
......@@ -81,11 +92,17 @@ def test_brevitas_avg_pool_export(
inp_tensor = inp * scale
input_tensor = torch.from_numpy(inp_tensor).float()
input_scale = torch.from_numpy(scale).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=input_scale, bit_width=ibw_tensor, signed=signed
input_quant_tensor = QuantTensor(
value=input_tensor,
scale=input_scale,
bit_width=ibw_tensor,
signed=signed,
zero_point=zpt,
)
# export again to set the scale values correctly
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
bo.export_finn_onnx(
b_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment