diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index ed32426abcc8ea71428a7f746a99454e8e4a2c17..2896b09e0f54d6d0492c5330ec5da4110e257d30 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -510,3 +510,41 @@ class ModelWrapper: qa.tensor_name = tensor_name qa.quant_parameter_tensor_names.append(dt) qnt_annotations.append(qa) + + def get_tensor_sparsity(self, tensor_name): + """Returns the sparsity of a given tensor as dictionary.""" + graph = self._model_proto.graph + qnt_annotations = graph.quantization_annotation + ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name") + if ret is not None: + ret = util.get_by_name( + ret.quant_parameter_tensor_names, "tensor_sparsity", "key" + ) + if ret is not None: + return eval(ret.value) + return None + + def set_tensor_sparsity(self, tensor_name, sparsity_dict): + """Sets the sparsity annotation of a tensor with given name.""" + graph = self._model_proto.graph + qnt_annotations = graph.quantization_annotation + ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name") + if ret is not None: + ret_ts = util.get_by_name( + ret.quant_parameter_tensor_names, "tensor_sparsity", "key" + ) + if ret_ts is not None: + ret_ts.value = str(sparsity_dict) + else: + ts = onnx.StringStringEntryProto() + ts.key = "tensor_sparsity" + ts.value = str(sparsity_dict) + ret.quant_parameter_tensor_names.append(ts) + else: + qa = onnx.TensorAnnotation() + dt = onnx.StringStringEntryProto() + dt.key = "tensor_sparsity" + dt.value = str(sparsity_dict) + qa.tensor_name = tensor_name + qa.quant_parameter_tensor_names.append(dt) + qnt_annotations.append(qa) diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 4bd9385536bc6721c66726169dfa4c69e5f06772..5fa9b23bad5c5b67f65530c55f862f889c07b1ac 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -73,6 +73,11 @@ def test_modelwrapper(): inp_layout = DataLayout.NCHW model.set_tensor_layout(inp_name, inp_layout) assert model.get_tensor_layout(inp_name) == inp_layout + inp_sparsity = model.get_tensor_sparsity(inp_name) + assert inp_sparsity is None + inp_sparsity = {"dw": {"kernel_shape": 3}} + model.set_tensor_sparsity(inp_name, inp_sparsity) + assert model.get_tensor_sparsity(inp_name) == inp_sparsity os.remove(export_onnx_path)