diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 80a5bca68b9e1b4520926e577508634ab2a3cf00..3eb58572590408619391a45996377bc6f506503f 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -74,7 +74,8 @@ def test_brevitas_to_onnx_export():
     export_onnx_path = "test_output_lfc.onnx"
     with torch.no_grad():
         lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-        bo.export_finn_onnx(lfc, (1, 784), export_onnx_path)
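+        # export now takes a single NCHW MNIST image (1, 1, 28, 28) as the
+        # input shape, replacing the previously flattened (1, 784) vector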
+        bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
         model = onnx.load(export_onnx_path)
         # TODO the following way of testing is highly sensitive to small changes
         # in PyTorch ONNX export: the order, names, count... of nodes could
@@ -85,15 +86,21 @@ def test_brevitas_to_onnx_export():
         assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
         act_node = model.graph.node[8]
         assert act_node.op_type == "QuantizedHardTanh"
-        assert act_node.attribute[0].name == "activation_qnt"
-        assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR"
         matmul_node = model.graph.node[9]
         assert matmul_node.op_type == "MatMul"
-        assert matmul_node.attribute[1].name == "weight_qnt"
-        assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR"
         assert act_node.output[0] == matmul_node.input[1]
         inits = [x.name for x in model.graph.initializer]
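+        # quantization metadata is now recorded as tensor-level
+        # quantization_annotation entries (tensor name -> datatype string)
+        # on the graph, rather than as node attributes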
+        qnt_annotations = {
+            a.tensor_name: a.quant_parameter_tensor_names[0].value
+            for a in model.graph.quantization_annotation
+        }
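+        # both the weight initializer (input[0]) and the activation output
+        # feeding the MatMul (input[1]) should be annotated as BIPOLAR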
+        assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
         assert matmul_node.input[0] in inits
+        assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
         init_ind = inits.index(matmul_node.input[0])
         int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
         int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])