Skip to content
Snippets Groups Projects
Commit b1eabe4b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

use quantization annotations instead of attributes

parent 528e7a15
No related branches found
No related tags found
No related merge requests found
......@@ -74,7 +74,7 @@ def test_brevitas_to_onnx_export():
export_onnx_path = "test_output_lfc.onnx"
with torch.no_grad():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
bo.export_finn_onnx(lfc, (1, 784), export_onnx_path)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = onnx.load(export_onnx_path)
# TODO the following way of testing is highly sensitive to small changes
# in PyTorch ONNX export: the order, names, count... of nodes could
......@@ -85,15 +85,17 @@ def test_brevitas_to_onnx_export():
assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
act_node = model.graph.node[8]
assert act_node.op_type == "QuantizedHardTanh"
assert act_node.attribute[0].name == "activation_qnt"
assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR"
matmul_node = model.graph.node[9]
assert matmul_node.op_type == "MatMul"
assert matmul_node.attribute[1].name == "weight_qnt"
assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR"
assert act_node.output[0] == matmul_node.input[1]
inits = [x.name for x in model.graph.initializer]
qnt_annotations = {
a.tensor_name: a.quant_parameter_tensor_names[0].value
for a in model.graph.quantization_annotation
}
assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
assert matmul_node.input[0] in inits
assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
init_ind = inits.index(matmul_node.input[0])
int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
......
Loading, or try again.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment