From b1eabe4bd58e7731158c5eccf5b3ccd9eea41f27 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 5 Sep 2019 23:10:59 +0100 Subject: [PATCH] use quantization annotations instead of attributes --- tests/test_brevitas_export.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 80a5bca68..3eb585725 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -74,7 +74,7 @@ def test_brevitas_to_onnx_export(): export_onnx_path = "test_output_lfc.onnx" with torch.no_grad(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) - bo.export_finn_onnx(lfc, (1, 784), export_onnx_path) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = onnx.load(export_onnx_path) # TODO the following way of testing is highly sensitive to small changes # in PyTorch ONNX export: the order, names, count... of nodes could @@ -85,15 +85,17 @@ def test_brevitas_to_onnx_export(): assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10 act_node = model.graph.node[8] assert act_node.op_type == "QuantizedHardTanh" - assert act_node.attribute[0].name == "activation_qnt" - assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR" matmul_node = model.graph.node[9] assert matmul_node.op_type == "MatMul" - assert matmul_node.attribute[1].name == "weight_qnt" - assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR" assert act_node.output[0] == matmul_node.input[1] inits = [x.name for x in model.graph.initializer] + qnt_annotations = { + a.tensor_name: a.quant_parameter_tensor_names[0].value + for a in model.graph.quantization_annotation + } + assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR" assert matmul_node.input[0] in inits + assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR" init_ind = inits.index(matmul_node.input[0]) int_weights_pytorch = lfc.features[2].int_weight.detach().numpy() int_weights_onnx = nph.to_array(model.graph.initializer[init_ind]) -- GitLab