Skip to content
Snippets Groups Projects
Commit 6ec329d6 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Added QONNX_export test to test_brevitas_QConv2d test.

parent 828de10f
No related branches found
No related tags found
No related merge requests found
......@@ -36,12 +36,15 @@ from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantConv2d
from qonnx.util.cleanup import cleanup as qonnx_cleanup
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import gen_finn_dt_tensor
export_onnx_path = "test_brevitas_conv.onnx"
......@@ -50,7 +53,8 @@ export_onnx_path = "test_brevitas_conv.onnx"
@pytest.mark.parametrize("dw", [False, True])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("in_channels", [32])
def test_brevitas_QConv2d(dw, bias, in_channels):
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_QConv2d(dw, bias, in_channels, QONNX_export):
ishape = (1, 32, 111, 111)
if dw is True:
groups = in_channels
......@@ -89,7 +93,15 @@ def test_brevitas_QConv2d(dw, bias, in_channels):
weight_tensor = gen_finn_dt_tensor(DataType.INT4, w_shape)
b_conv.weight = torch.nn.Parameter(torch.from_numpy(weight_tensor).float())
b_conv.eval()
bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
if QONNX_export:
m_path = export_onnx_path
BrevitasONNXManager.export(b_conv, ishape, m_path)
qonnx_cleanup(m_path, out_file=m_path)
model = ModelWrapper(m_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(m_path)
else:
bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment