diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index b8be06d4b55875937d434cb941efc0891b1655c6..ee69d7e74562252ae1f45c52a577b7ecf0cc895b 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -6,6 +6,7 @@ import onnx.numpy_helper as np_helper
 from onnx import TensorProto
 
 import finn.core.utils as util
+from finn.core.datatype import DataType
 
 
 class ModelWrapper:
@@ -88,6 +89,40 @@ class ModelWrapper:
         # TODO check that all constants are initializers
         return True
 
+    def get_tensor_datatype(self, tensor_name):
+        """Returns the FINN DataType of tensor with given name."""
+        graph = self._model_proto.graph
+        qnt_annotations = graph.quantization_annotation
+        ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
+        if ret is not None:
+            ret = util.get_by_name(
+                ret.quant_parameter_tensor_names, "finn_datatype", "key"
+            )
+            if ret is not None:
+                return DataType[ret.value]
+        # TODO maybe use native ONNX tensor type instead of assuming fp32?
+        return DataType["FLOAT32"]
+
+    def set_tensor_datatype(self, tensor_name, datatype):
+        """Sets the FINN DataType of tensor with given name."""
+        graph = self._model_proto.graph
+        qnt_annotations = graph.quantization_annotation
+        ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
+        if ret is not None:
+            ret = util.get_by_name(
+                ret.quant_parameter_tensor_names, "finn_datatype", "key"
+            )
+            if ret is not None:
+                ret.value = datatype.name
+        else:
+            qa = onnx.TensorAnnotation()
+            dt = onnx.StringStringEntryProto()
+            dt.key = "finn_datatype"
+            dt.value = datatype.name
+            qa.tensor_name = tensor_name
+            qa.quant_parameter_tensor_names.append(dt)
+            qnt_annotations.append(qa)
+
     def get_tensor_shape(self, tensor_name):
         """Returns the shape of tensor with given name, if it has ValueInfoProto."""
         graph = self._model_proto.graph
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 83babecf501788ee1600b49bc3ef461712877517..c86ef663b1f3814d9e5f0d9622e62c5c9504ffed 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -13,6 +13,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 import finn.core.onnx_exec as oxe
 import finn.transformation.infer_shapes as si
+from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 
 FC_OUT_FEATURES = [1024, 1024, 1024]
@@ -100,18 +101,12 @@ def test_brevitas_to_onnx_export():
     matmul_node = model.graph.node[4]
     assert matmul_node.op_type == "MatMul"
     assert act_node.output[0] == matmul_node.input[0]
-    inits = [x.name for x in model.graph.initializer]
-    qnt_annotations = {
-        a.tensor_name: a.quant_parameter_tensor_names[0].value
-        for a in model.graph.quantization_annotation
-    }
-    assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
-    assert matmul_node.input[1] in inits
-    assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
-    init_ind = inits.index(matmul_node.input[1])
+    assert model.get_tensor_datatype(matmul_node.input[0]) == DataType.BIPOLAR
+    W = model.get_initializer(matmul_node.input[1])
+    assert W is not None
+    assert model.get_tensor_datatype(matmul_node.input[1]) == DataType.BIPOLAR
 
     int_weights_pytorch = lfc.features[2].int_weight.transpose(1, 0).detach().numpy()
-    int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
-    assert (int_weights_onnx == int_weights_pytorch).all()
+    assert (W == int_weights_pytorch).all()
 
     os.remove(export_onnx_path)
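
As a usage sketch for review (not part of the patch): the snippet below exercises the get_tensor_datatype/set_tensor_datatype round trip this diff adds, on a hand-built graph instead of a Brevitas export. It assumes ModelWrapper also accepts an in-memory ModelProto (the test above only exercises the file-path constructor); the tensor names and shapes are made up for illustration.

import onnx.helper as oh
from onnx import TensorProto

from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper

# minimal single-MatMul graph, just enough to have named tensors to annotate
inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4])
W = oh.make_tensor_value_info("W", TensorProto.FLOAT, [4, 2])
outp = oh.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 2])
graph = oh.make_graph(
    [oh.make_node("MatMul", ["inp", "W"], ["outp"])],
    "dt_sketch",
    [inp, W],
    [outp],
)
model = ModelWrapper(oh.make_model(graph, producer_name="dt-sketch"))

# unannotated tensors fall back to FLOAT32, per the TODO in get_tensor_datatype
assert model.get_tensor_datatype("W") == DataType["FLOAT32"]

# the first set appends a fresh TensorAnnotation carrying a "finn_datatype" key
model.set_tensor_datatype("W", DataType.BIPOLAR)
assert model.get_tensor_datatype("W") == DataType.BIPOLAR

# later sets update the existing entry in place rather than appending a duplicate
model.set_tensor_datatype("W", DataType["FLOAT32"])
assert model.get_tensor_datatype("W") == DataType["FLOAT32"]

Since the datatype is stored as a plain string under the "finn_datatype" key of a standard ONNX TensorAnnotation, annotated models remain valid ONNX and the annotation survives a save/load round trip.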