From 5f99ef71088b1474d221167a647dd16f2e7817d2 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 14 May 2020 11:19:16 +0100
Subject: [PATCH] [Test] fix broken test_brevitas_act_export

---
 .../test_brevitas_non_scaled_QuantHardTanh_export.py   | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
index a50e8eae7..c22f30c6f 100644
--- a/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
+++ b/tests/brevitas/test_brevitas_non_scaled_QuantHardTanh_export.py
@@ -5,11 +5,11 @@ import brevitas.onnx as bo
 from brevitas.nn import QuantHardTanh
 from brevitas.core.restrict_val import RestrictValueType
 from brevitas.core.scaling import ScalingImplType
-from models.common import get_quant_type
 import pytest
 from finn.core.modelwrapper import ModelWrapper
 import finn.core.onnx_exec as oxe
 from finn.transformation.infer_shapes import InferShapes
+from brevitas.core.quant import QuantType
 
 export_onnx_path = "test_act.onnx"
 
@@ -18,6 +18,14 @@ export_onnx_path = "test_act.onnx"
 @pytest.mark.parametrize("narrow_range", [False, True])
 @pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
 def test_brevitas_act_export(abits, narrow_range, max_val):
+    def get_quant_type(bit_width):
+        if bit_width is None:
+            return QuantType.FP
+        elif bit_width == 1:
+            return QuantType.BINARY
+        else:
+            return QuantType.INT
+
     act_quant_type = get_quant_type(abits)
     min_val = -1.0
     ishape = (1, 10)
-- 
GitLab