diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 80a5bca68b9e1b4520926e577508634ab2a3cf00..3eb58572590408619391a45996377bc6f506503f 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -74,7 +74,7 @@ def test_brevitas_to_onnx_export(): export_onnx_path = "test_output_lfc.onnx" with torch.no_grad(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) - bo.export_finn_onnx(lfc, (1, 784), export_onnx_path) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = onnx.load(export_onnx_path) # TODO the following way of testing is highly sensitive to small changes # in PyTorch ONNX export: the order, names, count... of nodes could @@ -85,15 +85,17 @@ def test_brevitas_to_onnx_export(): assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10 act_node = model.graph.node[8] assert act_node.op_type == "QuantizedHardTanh" - assert act_node.attribute[0].name == "activation_qnt" - assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR" matmul_node = model.graph.node[9] assert matmul_node.op_type == "MatMul" - assert matmul_node.attribute[1].name == "weight_qnt" - assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR" assert act_node.output[0] == matmul_node.input[1] inits = [x.name for x in model.graph.initializer] + qnt_annotations = { + a.tensor_name: a.quant_parameter_tensor_names[0].value + for a in model.graph.quantization_annotation + } + assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR" assert matmul_node.input[0] in inits + assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR" init_ind = inits.index(matmul_node.input[0]) int_weights_pytorch = lfc.features[2].int_weight.detach().numpy() int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])