diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 4baa12aa5215ffc7351b08a6b2e9868402ec749d..30dae7d2bd37516a887ba5ca20c1398af75905f3 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -13,7 +13,7 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=d45ac15325c7f33de6a9d2d2f654ef48cb20c49d
+BREVITAS_COMMIT=093de7d138c6715dbcaf82a9e1d530069327ad98
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
 PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 3bf6a8ed6f86ba4e54f49b64844a7a51337be445..17b2bc5aaa1c3dfd16ed42cae71dec948e9114a9 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -1,3 +1,5 @@
+import os
+
 import onnx  # noqa
 import torch
 import numpy as np
@@ -5,6 +7,12 @@ import brevitas.onnx as bo
 from brevitas.nn import QuantAvgPool2d
 from brevitas.quant_tensor import pack_quant_tensor
 from brevitas.core.quant import QuantType
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import gen_finn_dt_tensor
+import finn.core.onnx_exec as oxe
 
 import pytest
 
@@ -36,3 +44,32 @@
         tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
     )
     bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
+    model = ModelWrapper(export_onnx_path)
+    # set the FINN datatype of the model input from the test parameters
+    if signed is True:
+        prefix = "INT"
+    else:
+        prefix = "UINT"
+    dt_name = prefix + str(bit_width)
+    dtype = DataType[dt_name]
+    model.set_tensor_datatype(model.graph.input[0].name, dtype)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    # calculate golden output with the Brevitas layer
+    inp = gen_finn_dt_tensor(dtype, ishape)
+    input_tensor = torch.from_numpy(inp).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+    )
+    b_avgpool.eval()
+    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
+
+    # execute the exported model with FINN
+    idict = {model.graph.input[0].name: inp}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    # the FINN execution result should match the Brevitas reference
+    assert np.isclose(expected, produced).all()
+
+    os.remove(export_onnx_path)
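
Note on manual verification: the FINN-side flow exercised by the new test can also be run by hand on any model produced by bo.export_finn_onnx(). The snippet below is a minimal sketch only; the file name "quant_avg_pool.onnx" and the INT4 input datatype are placeholders, and a FINN dev environment with the updated BREVITAS_COMMIT is assumed.

    import finn.core.onnx_exec as oxe
    from finn.core.datatype import DataType
    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.infer_datatypes import InferDataTypes
    from finn.transformation.infer_shapes import InferShapes
    from finn.util.basic import gen_finn_dt_tensor

    # placeholder file name: use the ONNX file written by bo.export_finn_onnx()
    model = ModelWrapper("quant_avg_pool.onnx")
    iname = model.graph.input[0].name
    # placeholder datatype: pick the FINN datatype matching the quantized input
    model.set_tensor_datatype(iname, DataType["INT4"])
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # random input of the declared datatype, run through FINN's ONNX executor
    inp = gen_finn_dt_tensor(DataType["INT4"], model.get_tensor_shape(iname))
    odict = oxe.execute_onnx(model, {iname: inp})
    produced = odict[model.graph.output[0].name]

This is the same ModelWrapper / InferShapes / InferDataTypes / execute_onnx sequence the test uses before comparing against the Brevitas reference output.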