Skip to content
Snippets Groups Projects
Commit a9a70b98 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Added QONNX_export test to test_brevitas_qlinear test.

parent 6ec329d6
No related branches found
No related tags found
No related merge requests found
......@@ -33,12 +33,15 @@ import numpy as np
import os
import torch
from brevitas.core.quant import QuantType
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantLinear
from qonnx.util.cleanup import cleanup as qonnx_cleanup
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import gen_finn_dt_tensor
export_onnx_path = "test_brevitas_qlinear.onnx"
......@@ -49,7 +52,10 @@ export_onnx_path = "test_brevitas_qlinear.onnx"
@pytest.mark.parametrize("in_features", [3])
@pytest.mark.parametrize("w_bits", [4])
@pytest.mark.parametrize("i_dtype", [DataType.UINT4])
def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype):
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_qlinear(
bias, out_features, in_features, w_bits, i_dtype, QONNX_export
):
i_shape = (1, in_features)
w_shape = (out_features, in_features)
b_linear = QuantLinear(
......@@ -66,7 +72,15 @@ def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype):
)
b_linear.weight.data = torch.from_numpy(weight_tensor_fp)
b_linear.eval()
bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
if QONNX_export:
m_path = export_onnx_path
BrevitasONNXManager.export(b_linear, i_shape, m_path)
qonnx_cleanup(m_path, out_file=m_path)
model = ModelWrapper(m_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(m_path)
else:
bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
inp_tensor = gen_finn_dt_tensor(i_dtype, i_shape)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment