From 5f99ef71088b1474d221167a647dd16f2e7817d2 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 14 May 2020 11:19:16 +0100 Subject: [PATCH] [Test] fix broken test_brevitas_act_export --- .../test_brevitas_non_scaled_QuantHardTanh_export.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py index a50e8eae7..c22f30c6f 100644 --- a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py +++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py @@ -5,11 +5,11 @@ import brevitas.onnx as bo from brevitas.nn import QuantHardTanh from brevitas.core.restrict_val import RestrictValueType from brevitas.core.scaling import ScalingImplType -from models.common import get_quant_type import pytest from finn.core.modelwrapper import ModelWrapper import finn.core.onnx_exec as oxe from finn.transformation.infer_shapes import InferShapes +from brevitas.core.quant import QuantType export_onnx_path = "test_act.onnx" @@ -18,6 +18,14 @@ export_onnx_path = "test_act.onnx" @pytest.mark.parametrize("narrow_range", [False, True]) @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)]) def test_brevitas_act_export(abits, narrow_range, max_val): + def get_quant_type(bit_width): + if bit_width is None: + return QuantType.FP + elif bit_width == 1: + return QuantType.BINARY + else: + return QuantType.INT + act_quant_type = get_quant_type(abits) min_val = -1.0 ishape = (1, 10) -- GitLab