From 528e7a15d244bdc463d10ee60e6abad1cc696458 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Wed, 4 Sep 2019 15:51:34 +0100 Subject: [PATCH] use new export_finn_onnx function from Brevitas --- tests/test_brevitas_export.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index b3aa98793..80a5bca68 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -74,29 +74,28 @@ def test_brevitas_to_onnx_export(): export_onnx_path = "test_output_lfc.onnx" with torch.no_grad(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) - lfc = lfc.eval() - bo.prepare_for_onnx_export(lfc, True) - torch.onnx.export( - lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True - ) + bo.export_finn_onnx(lfc, (1, 784), export_onnx_path) model = onnx.load(export_onnx_path) # TODO the following way of testing is highly sensitive to small changes # in PyTorch ONNX export: the order, names, count... of nodes could # easily change between different versions, and break this test. - assert len(model.graph.input) == 32 - assert len(model.graph.node) == 33 + assert len(model.graph.input) == 21 + assert len(model.graph.node) == 25 assert len(model.graph.output) == 1 assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10 - assert model.graph.node[12].op_type == "QuantizedHardTanh" - assert model.graph.node[13].op_type == "Constant" - assert model.graph.node[14].op_type == "MatMul" - assert model.graph.node[12].output[0] == model.graph.node[14].input[1] - assert model.graph.node[13].output[0] == model.graph.node[14].input[0] + act_node = model.graph.node[8] + assert act_node.op_type == "QuantizedHardTanh" + assert act_node.attribute[0].name == "activation_qnt" + assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR" + matmul_node = model.graph.node[9] + assert matmul_node.op_type == "MatMul" + assert matmul_node.attribute[1].name == "weight_qnt" + assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR" + assert act_node.output[0] == matmul_node.input[1] + inits = [x.name for x in model.graph.initializer] + assert matmul_node.input[0] in inits + init_ind = inits.index(matmul_node.input[0]) int_weights_pytorch = lfc.features[2].int_weight.detach().numpy() - int_weights_onnx = nph.to_array(model.graph.node[13].attribute[0].t) + int_weights_onnx = nph.to_array(model.graph.initializer[init_ind]) assert (int_weights_onnx == int_weights_pytorch).all() - assert model.graph.node[12].attribute[0].name == "activation_qnt" - assert model.graph.node[12].attribute[0].s.decode("utf-8") == "BIPOLAR" - assert model.graph.node[14].attribute[1].name == "weight_qnt" - assert model.graph.node[14].attribute[1].s.decode("utf-8") == "BIPOLAR" os.remove(export_onnx_path) -- GitLab