From b1eabe4bd58e7731158c5eccf5b3ccd9eea41f27 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 5 Sep 2019 23:10:59 +0100
Subject: [PATCH] use quantization annotations instead of attributes

---
 tests/test_brevitas_export.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 80a5bca68..3eb585725 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -74,7 +74,7 @@ def test_brevitas_to_onnx_export():
     export_onnx_path = "test_output_lfc.onnx"
     with torch.no_grad():
         lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-        bo.export_finn_onnx(lfc, (1, 784), export_onnx_path)
+        bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
         model = onnx.load(export_onnx_path)
         # TODO the following way of testing is highly sensitive to small changes
         # in PyTorch ONNX export: the order, names, count... of nodes could
@@ -85,15 +85,17 @@ def test_brevitas_to_onnx_export():
         assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
         act_node = model.graph.node[8]
         assert act_node.op_type == "QuantizedHardTanh"
-        assert act_node.attribute[0].name == "activation_qnt"
-        assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR"
         matmul_node = model.graph.node[9]
         assert matmul_node.op_type == "MatMul"
-        assert matmul_node.attribute[1].name == "weight_qnt"
-        assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR"
         assert act_node.output[0] == matmul_node.input[1]
         inits = [x.name for x in model.graph.initializer]
+        qnt_annotations = {
+            a.tensor_name: a.quant_parameter_tensor_names[0].value
+            for a in model.graph.quantization_annotation
+        }
+        assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
         assert matmul_node.input[0] in inits
+        assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
         init_ind = inits.index(matmul_node.input[0])
         int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
         int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
-- 
GitLab