Commit 528e7a15 authored by Yaman Umuroglu

use new export_finn_onnx function from Brevitas

parent b8bc8745
@@ -74,29 +74,28 @@ def test_brevitas_to_onnx_export():
     export_onnx_path = "test_output_lfc.onnx"
     with torch.no_grad():
         lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-        lfc = lfc.eval()
-        bo.prepare_for_onnx_export(lfc, True)
-        torch.onnx.export(
-            lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True
-        )
+        bo.export_finn_onnx(lfc, (1, 784), export_onnx_path)
         model = onnx.load(export_onnx_path)
         # TODO the following way of testing is highly sensitive to small changes
         # in PyTorch ONNX export: the order, names, count... of nodes could
         # easily change between different versions, and break this test.
-        assert len(model.graph.input) == 32
-        assert len(model.graph.node) == 33
+        assert len(model.graph.input) == 21
+        assert len(model.graph.node) == 25
         assert len(model.graph.output) == 1
         assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
-        assert model.graph.node[12].op_type == "QuantizedHardTanh"
-        assert model.graph.node[13].op_type == "Constant"
-        assert model.graph.node[14].op_type == "MatMul"
-        assert model.graph.node[12].output[0] == model.graph.node[14].input[1]
-        assert model.graph.node[13].output[0] == model.graph.node[14].input[0]
+        act_node = model.graph.node[8]
+        assert act_node.op_type == "QuantizedHardTanh"
+        assert act_node.attribute[0].name == "activation_qnt"
+        assert act_node.attribute[0].s.decode("utf-8") == "BIPOLAR"
+        matmul_node = model.graph.node[9]
+        assert matmul_node.op_type == "MatMul"
+        assert matmul_node.attribute[1].name == "weight_qnt"
+        assert matmul_node.attribute[1].s.decode("utf-8") == "BIPOLAR"
+        assert act_node.output[0] == matmul_node.input[1]
+        inits = [x.name for x in model.graph.initializer]
+        assert matmul_node.input[0] in inits
+        init_ind = inits.index(matmul_node.input[0])
         int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
-        int_weights_onnx = nph.to_array(model.graph.node[13].attribute[0].t)
+        int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
         assert (int_weights_onnx == int_weights_pytorch).all()
-        assert model.graph.node[12].attribute[0].name == "activation_qnt"
-        assert model.graph.node[12].attribute[0].s.decode("utf-8") == "BIPOLAR"
-        assert model.graph.node[14].attribute[1].name == "weight_qnt"
-        assert model.graph.node[14].attribute[1].s.decode("utf-8") == "BIPOLAR"
         os.remove(export_onnx_path)
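
For reference, a minimal standalone sketch of the new export flow this commit switches to. The bo.export_finn_onnx call and LFC constructor arguments are taken directly from the diff above; the import paths (brevitas.onnx imported as bo, LFC from the models package) are assumptions mirroring the surrounding test file, not confirmed by this commit.

import onnx
import torch
import brevitas.onnx as bo
from models.LFC import LFC  # assumed import path, mirroring the test module

export_onnx_path = "lfc_w1a1_export.onnx"
with torch.no_grad():
    # build the 1-bit weight / 1-bit activation LFC and export it in a single
    # call; export_finn_onnx replaces the old prepare_for_onnx_export plus
    # torch.onnx.export sequence removed in this commit
    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
    bo.export_finn_onnx(lfc, (1, 784), export_onnx_path)
    # quick sanity check on the exported graph, mirroring the test's asserts
    model = onnx.load(export_onnx_path)
    assert len(model.graph.output) == 1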