From 313176337deff5c703381598679a3939b8fc350f Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Mon, 25 May 2020 16:33:48 +0100 Subject: [PATCH] [Test] Add test for QConv2d export --- tests/brevitas/test_brevitas_QConv2d.py | 76 +++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/brevitas/test_brevitas_QConv2d.py diff --git a/tests/brevitas/test_brevitas_QConv2d.py b/tests/brevitas/test_brevitas_QConv2d.py new file mode 100644 index 000000000..198f1e796 --- /dev/null +++ b/tests/brevitas/test_brevitas_QConv2d.py @@ -0,0 +1,76 @@ +import pytest +import os +import numpy as np +import torch +import brevitas.onnx as bo +from brevitas.nn import QuantConv2d +from brevitas.core.restrict_val import RestrictValueType +from brevitas.core.quant import QuantType +from brevitas.core.scaling import ScalingImplType +from brevitas.core.stats import StatsOp + +from finn.core.modelwrapper import ModelWrapper +from finn.core.datatype import DataType +import finn.core.onnx_exec as oxe +from finn.transformation.infer_shapes import InferShapes +from finn.util.basic import gen_finn_dt_tensor + +export_onnx_path = "test_brevitas_conv.onnx" + + +@pytest.mark.parametrize("dw", [False, True]) +@pytest.mark.parametrize("in_channels", [32]) +def test_brevitas_QConv2d(dw, in_channels): + ishape = (1, 32, 111, 111) + if dw is True: + groups = in_channels + out_channels = in_channels + kernel_size = 3 + padding = 1 + stride = 1 + w_shape = (32, 1, 3, 3) + + else: + groups = 1 + out_channels = 64 + kernel_size = 1 + padding = 0 + stride = 1 + w_shape = (64, 32, 1, 1) + + b_conv = QuantConv2d( + in_channels=in_channels, + out_channels=out_channels, + groups=groups, + kernel_size=kernel_size, + padding=padding, + stride=stride, + bias=False, + bias_quant_type=QuantType.FP, + compute_output_bit_width=False, + compute_output_scale=False, + weight_bit_width=4, + weight_quant_type=QuantType.INT, + weight_scaling_impl_type=ScalingImplType.STATS, + weight_scaling_stats_op=StatsOp.MAX, + weight_scaling_per_output_channel=True, + weight_restrict_scaling_type=RestrictValueType.LOG_FP, + weight_narrow_range=True, + weight_scaling_min_val=2e-16, + ) + weight_tensor = gen_finn_dt_tensor(DataType.INT4, w_shape) + b_conv.weight = torch.nn.Parameter(torch.from_numpy(weight_tensor).float()) + + bo.export_finn_onnx(b_conv, ishape, export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32) + idict = {model.graph.input[0].name: inp_tensor} + odict = oxe.execute_onnx(model, idict, True) + produced = odict[model.graph.output[0].name] + inp_tensor = torch.from_numpy(inp_tensor).float() + b_conv.eval() + expected = b_conv.forward(inp_tensor).detach().numpy() + + assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_onnx_path) -- GitLab