diff --git a/tests/brevitas/test_brevitas_act_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
similarity index 94%
rename from tests/brevitas/test_brevitas_act_export.py
rename to tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
index e1cfc9db98a9c9d746b6f66ee071ddfb85cc5dbb..b66348a9902802bc65b2a35e8bc3e311cc81e0bc 100644
--- a/tests/brevitas/test_brevitas_act_export.py
+++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
@@ -1,15 +1,16 @@
 import os
+import onnx  # noqa
 import numpy as np
 import torch
 import brevitas.onnx as bo
 from brevitas.nn import QuantHardTanh
 from brevitas.core.restrict_val import RestrictValueType
-from brevitas.core.quant import QuantType
 from brevitas.core.scaling import ScalingImplType
 import pytest
 from finn.core.modelwrapper import ModelWrapper
 import finn.core.onnx_exec as oxe
 from finn.transformation.infer_shapes import InferShapes
+from brevitas.core.quant import QuantType
 
 export_onnx_path = "test_act.onnx"
 
@@ -17,7 +18,7 @@ export_onnx_path = "test_act.onnx"
 @pytest.mark.parametrize("abits", [1, 2, 4, 8])
 @pytest.mark.parametrize("narrow_range", [False, True])
 @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
-def test_brevitas_act_export(abits, narrow_range, max_val):
+def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val):
     def get_quant_type(bit_width):
         if bit_width is None:
             return QuantType.FP
diff --git a/tests/brevitas/test_brevitas_relu_act_export.py b/tests/brevitas/test_brevitas_relu_act_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d8f2d812bc7bea1a2fd2598a7711099ad421e6
--- /dev/null
+++ b/tests/brevitas/test_brevitas_relu_act_export.py
@@ -0,0 +1,69 @@
+import os
+import onnx  # noqa
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantReLU
+from brevitas.core.quant import QuantType
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+import pytest
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+
+export_onnx_path = "test_act.onnx"
+
+
+@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
+@pytest.mark.parametrize(
+    "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
+)
+def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type):
+    min_val = -1.0
+    ishape = (1, 15)
+    b_act = QuantReLU(
+        bit_width=abits,
+        max_val=max_val,
+        scaling_impl_type=scaling_impl_type,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        quant_type=QuantType.INT,
+    )
+    if scaling_impl_type == ScalingImplType.PARAMETER:
+        checkpoint = {
+            "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
+scaling_impl.learned_value": torch.tensor(
+                0.49
+            ).type(
+                torch.FloatTensor
+            )
+        }
+        b_act.load_state_dict(checkpoint)
+
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    b_act.eval()
+    expected = b_act.forward(inp_tensor).detach().numpy()
+    if not np.isclose(produced, expected, atol=1e-3).all():
+        print(abits, max_val, scaling_impl_type)
+        print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
+        if abits < 5:
+            print(
+                "thres:",
+                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
+            )
+        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
+        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
+        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
+
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
diff --git a/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d499f1517341477eca9915245da9ad12c346c5a9
--- /dev/null
+++ b/tests/brevitas/test_brevitas_scaled_QHardTanh_export.py
@@ -0,0 +1,93 @@
+import onnx  # noqa
+import os
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantHardTanh
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.quant import QuantType
+from brevitas.core.scaling import ScalingImplType
+import pytest
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+
+export_onnx_path = "test_act.onnx"
+
+
+@pytest.mark.parametrize("abits", [2, 4, 8])
+@pytest.mark.parametrize("narrow_range", [False, True])
+@pytest.mark.parametrize("min_val", [-1.0, -(1 - 2 ** (-7)), -2])
+@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7), 2])
+@pytest.mark.parametrize(
+    "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
+)
+def test_brevitas_act_export_qhardtanh_scaled(
+    abits, narrow_range, min_val, max_val, scaling_impl_type
+):
+    def get_quant_type(bit_width):
+        if bit_width is None:
+            return QuantType.FP
+        elif bit_width == 1:
+            return QuantType.BINARY
+        else:
+            return QuantType.INT
+
+    act_quant_type = get_quant_type(abits)
+    ishape = (1, 15)
+    b_act = QuantHardTanh(
+        bit_width=abits,
+        quant_type=act_quant_type,
+        max_val=max_val,
+        min_val=min_val,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        scaling_impl_type=scaling_impl_type,
+        narrow_range=narrow_range,
+    )
+    if scaling_impl_type == ScalingImplType.PARAMETER:
+        checkpoint = {
+            "act_quant_proxy.fused_activation_quant_proxy.\
+tensor_quant.scaling_impl.learned_value": torch.tensor(
+                0.49
+            ).type(
+                torch.FloatTensor
+            )
+        }
+        b_act.load_state_dict(checkpoint)
+
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    b_act.eval()
+    expected = b_act.forward(inp_tensor).detach().numpy()
+    if not np.isclose(produced, expected, atol=1e-3).all():
+        print(
+            "abits: ",
+            abits,
+            " | narrow_range: ",
+            narrow_range,
+            " | min_val: ",
+            min_val,
+            " | max_val: ",
+            max_val,
+        )
+        print("layer scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
+        print("export scale: ", b_act.export_act_scale)
+        if abits < 5:
+            print(
+                "thres:",
+                ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]]),
+            )
+        print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
+        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
+        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
+
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)