Skip to content
Snippets Groups Projects
Commit 82f87146 authored by auphelia's avatar auphelia
Browse files

[Docker & Test] Update brevitas version and extend avg pool export testing

parent 81deb807
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ gecho () { ...@@ -13,7 +13,7 @@ gecho () {
# checkout the correct dependency repo commits # checkout the correct dependency repo commits
# the repos themselves are cloned in the Dockerfile # the repos themselves are cloned in the Dockerfile
BREVITAS_COMMIT=d45ac15325c7f33de6a9d2d2f654ef48cb20c49d BREVITAS_COMMIT=093de7d138c6715dbcaf82a9e1d530069327ad98
CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6 PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
......
import os
import onnx # noqa import onnx # noqa
import torch import torch
import numpy as np import numpy as np
...@@ -5,6 +7,12 @@ import brevitas.onnx as bo ...@@ -5,6 +7,12 @@ import brevitas.onnx as bo
from brevitas.nn import QuantAvgPool2d from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import pack_quant_tensor from brevitas.quant_tensor import pack_quant_tensor
from brevitas.core.quant import QuantType from brevitas.core.quant import QuantType
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor
import finn.core.onnx_exec as oxe
import pytest import pytest
...@@ -36,3 +44,30 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width): ...@@ -36,3 +44,30 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
) )
bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor) bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
model = ModelWrapper(export_onnx_path)
# set FINN datatype
if signed is True:
prefix = "INT"
else:
prefix = "UINT"
dt_name = prefix + str(bit_width)
dtype = DataType[dt_name]
model.set_tensor_datatype(model.graph.input[0].name, dtype)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
# calculate golden output
inp = gen_finn_dt_tensor(dtype, ishape)
input_tensor = torch.from_numpy(inp).float()
input_quant_tensor = pack_quant_tensor(
tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
)
b_avgpool.eval()
expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
# finn execution
idict = {model.graph.input[0].name : inp}
odict = oxe.execute_onnx(model, idict, True)
produced = odict[model.graph.output[0].name]
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment