diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index b8be06d4b55875937d434cb941efc0891b1655c6..ee69d7e74562252ae1f45c52a577b7ecf0cc895b 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -6,6 +6,7 @@ import onnx.numpy_helper as np_helper
 from onnx import TensorProto
 
 import finn.core.utils as util
+from finn.core.datatype import DataType
 
 
 class ModelWrapper:
@@ -88,6 +89,48 @@ class ModelWrapper:
         # TODO check that all constants are initializers
         return True
 
+    def get_tensor_datatype(self, tensor_name):
+        """Returns the FINN DataType of tensor with given name."""
+        graph = self._model_proto.graph
+        qnt_annotations = graph.quantization_annotation
+        ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
+        if ret is not None:
+            ret_dt = util.get_by_name(
+                ret.quant_parameter_tensor_names, "finn_datatype", "key"
+            )
+            if ret_dt is not None:
+                return DataType[ret_dt.value]
+        # TODO maybe use native ONNX tensor type instead of assuming fp32?
+        return DataType["FLOAT32"]
+
+    def set_tensor_datatype(self, tensor_name, datatype):
+        """Sets the FINN DataType of tensor with given name."""
+        graph = self._model_proto.graph
+        qnt_annotations = graph.quantization_annotation
+        ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
+        if ret is not None:
+            ret_dt = util.get_by_name(
+                ret.quant_parameter_tensor_names, "finn_datatype", "key"
+            )
+            if ret_dt is not None:
+                # annotation entry already exists, update its value
+                ret_dt.value = datatype.name
+            else:
+                # annotation exists for tensor but lacks a finn_datatype entry
+                dt = onnx.StringStringEntryProto()
+                dt.key = "finn_datatype"
+                dt.value = datatype.name
+                ret.quant_parameter_tensor_names.append(dt)
+        else:
+            # no annotation for this tensor yet, create one from scratch
+            qa = onnx.TensorAnnotation()
+            dt = onnx.StringStringEntryProto()
+            dt.key = "finn_datatype"
+            dt.value = datatype.name
+            qa.tensor_name = tensor_name
+            qa.quant_parameter_tensor_names.append(dt)
+            qnt_annotations.append(qa)
+
     def get_tensor_shape(self, tensor_name):
         """Returns the shape of tensor with given name, if it has ValueInfoProto."""
         graph = self._model_proto.graph
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 83babecf501788ee1600b49bc3ef461712877517..c86ef663b1f3814d9e5f0d9622e62c5c9504ffed 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -13,6 +13,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 import finn.core.onnx_exec as oxe
 import finn.transformation.infer_shapes as si
+from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 
 FC_OUT_FEATURES = [1024, 1024, 1024]
@@ -100,18 +101,12 @@ def test_brevitas_to_onnx_export():
     matmul_node = model.graph.node[4]
     assert matmul_node.op_type == "MatMul"
     assert act_node.output[0] == matmul_node.input[0]
-    inits = [x.name for x in model.graph.initializer]
-    qnt_annotations = {
-        a.tensor_name: a.quant_parameter_tensor_names[0].value
-        for a in model.graph.quantization_annotation
-    }
-    assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
-    assert matmul_node.input[1] in inits
-    assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
-    init_ind = inits.index(matmul_node.input[1])
+    assert model.get_tensor_datatype(matmul_node.input[0]) == DataType.BIPOLAR
+    W = model.get_initializer(matmul_node.input[1])
+    assert W is not None
+    assert model.get_tensor_datatype(matmul_node.input[1]) == DataType.BIPOLAR
     int_weights_pytorch = lfc.features[2].int_weight.transpose(1, 0).detach().numpy()
-    int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
-    assert (int_weights_onnx == int_weights_pytorch).all()
+    assert (W == int_weights_pytorch).all()
     os.remove(export_onnx_path)
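
For reviewers, a minimal usage sketch of the new annotation API (`model.onnx` is a placeholder path; this assumes `ModelWrapper` can be constructed from an ONNX file path, as in the tests):

```python
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper

# placeholder path, for illustration only
model = ModelWrapper("model.onnx")
tensor_name = model.graph.input[0].name

# tensors without a finn_datatype annotation default to FLOAT32
assert model.get_tensor_datatype(tensor_name) == DataType.FLOAT32

# set_tensor_datatype creates or updates the quantization_annotation
# entry in place; the value round-trips through the getter
model.set_tensor_datatype(tensor_name, DataType.BIPOLAR)
assert model.get_tensor_datatype(tensor_name) == DataType.BIPOLAR
```

Storing the datatype under a `finn_datatype` key in `quantization_annotation` keeps the annotation inside the ONNX protobuf itself, so it survives save/load without any side-channel metadata.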