Skip to content
Snippets Groups Projects
Commit 00814af4 authored by auphelia's avatar auphelia
Browse files

Merge branch 'feature/sparsity_annotation' into feature/depthwise_convolution

parents 80e3d258 779ebda8
No related branches found
No related tags found
No related merge requests found
......@@ -510,3 +510,41 @@ class ModelWrapper:
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)
def get_tensor_sparsity(self, tensor_name):
"""Returns the sparsity of a given tensor as dictionary."""
graph = self._model_proto.graph
qnt_annotations = graph.quantization_annotation
ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
if ret is not None:
ret = util.get_by_name(
ret.quant_parameter_tensor_names, "tensor_sparsity", "key"
)
if ret is not None:
return eval(ret.value)
return None
def set_tensor_sparsity(self, tensor_name, sparsity_dict):
"""Sets the sparsity annotation of a tensor with given name."""
graph = self._model_proto.graph
qnt_annotations = graph.quantization_annotation
ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
if ret is not None:
ret_ts = util.get_by_name(
ret.quant_parameter_tensor_names, "tensor_sparsity", "key"
)
if ret_ts is not None:
ret_ts.value = str(sparsity_dict)
else:
ts = onnx.StringStringEntryProto()
ts.key = "tensor_sparsity"
ts.value = str(sparsity_dict)
ret.quant_parameter_tensor_names.append(ts)
else:
qa = onnx.TensorAnnotation()
dt = onnx.StringStringEntryProto()
dt.key = "tensor_sparsity"
dt.value = str(sparsity_dict)
qa.tensor_name = tensor_name
qa.quant_parameter_tensor_names.append(dt)
qnt_annotations.append(qa)
......@@ -73,6 +73,11 @@ def test_modelwrapper():
inp_layout = DataLayout.NCHW
model.set_tensor_layout(inp_name, inp_layout)
assert model.get_tensor_layout(inp_name) == inp_layout
inp_sparsity = model.get_tensor_sparsity(inp_name)
assert inp_sparsity is None
inp_sparsity = {"dw": {"kernel_shape": 3}}
model.set_tensor_sparsity(inp_name, inp_sparsity)
assert model.get_tensor_sparsity(inp_name) == inp_sparsity
os.remove(export_onnx_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment