Commit dcb16326 authored by Yaman Umuroglu

[Test] fix LFC export test after Brevitas updates

parent c2e272b9
@@ -5,8 +5,6 @@ from operator import mul
 import brevitas.onnx as bo
 import onnx
 import onnx.numpy_helper as nph
-import torch
-import torch.onnx
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
@@ -64,40 +62,39 @@ def test_brevitas_to_onnx_export():
             )
 
         def forward(self, x):
-            x = x.view(1, 784)
             x = 2.0 * x - 1.0
+            x = x.view(x.shape[0], -1)
             for mod in self.features:
                 x = mod(x)
             out = self.fc(x)
             return out
 
     export_onnx_path = "test_output_lfc.onnx"
-    with torch.no_grad():
-        lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-        bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
-        model = onnx.load(export_onnx_path)
-        # TODO the following way of testing is highly sensitive to small changes
-        # in PyTorch ONNX export: the order, names, count... of nodes could
-        # easily change between different versions, and break this test.
-        assert len(model.graph.input) == 25
-        assert len(model.graph.node) == 29
-        assert len(model.graph.output) == 1
-        assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
-        act_node = model.graph.node[8]
-        assert act_node.op_type == "Sign"
-        matmul_node = model.graph.node[9]
-        assert matmul_node.op_type == "MatMul"
-        assert act_node.output[0] == matmul_node.input[1]
-        inits = [x.name for x in model.graph.initializer]
-        qnt_annotations = {
-            a.tensor_name: a.quant_parameter_tensor_names[0].value
-            for a in model.graph.quantization_annotation
-        }
-        assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
-        assert matmul_node.input[0] in inits
-        assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
-        init_ind = inits.index(matmul_node.input[0])
-        int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
-        int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
-        assert (int_weights_onnx == int_weights_pytorch).all()
-        os.remove(export_onnx_path)
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = onnx.load(export_onnx_path)
+    # TODO the following way of testing is highly sensitive to small changes
+    # in PyTorch ONNX export: the order, names, count... of nodes could
+    # easily change between different versions, and break this test.
+    assert len(model.graph.input) == 23
+    assert len(model.graph.node) == 24
+    assert len(model.graph.output) == 1
+    assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
+    act_node = model.graph.node[3]
+    assert act_node.op_type == "Sign"
+    matmul_node = model.graph.node[4]
+    assert matmul_node.op_type == "MatMul"
+    assert act_node.output[0] == matmul_node.input[0]
+    inits = [x.name for x in model.graph.initializer]
+    qnt_annotations = {
+        a.tensor_name: a.quant_parameter_tensor_names[0].value
+        for a in model.graph.quantization_annotation
+    }
+    assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
+    assert matmul_node.input[1] in inits
+    assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
+    init_ind = inits.index(matmul_node.input[1])
+    int_weights_pytorch = lfc.features[2].int_weight.transpose(1, 0).detach().numpy()
+    int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
+    assert (int_weights_onnx == int_weights_pytorch).all()
+    os.remove(export_onnx_path)
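
The TODO in the diff still applies after this fix: the assertions hard-code node indices (3 and 4) and tensor positions, so the next Brevitas or PyTorch ONNX export change can shift them again. A minimal sketch of an index-robust lookup, assuming the same loaded `model` as in the test; the `find_consumer` helper is hypothetical, not part of this test suite:

    # Locate the Sign -> MatMul pair by op_type and tensor connectivity
    # instead of hard-coded node indices, which shift between export versions.
    def find_consumer(graph, tensor_name):
        # Return the first node that reads tensor_name, or None if absent.
        for node in graph.node:
            if tensor_name in node.input:
                return node
        return None

    sign_nodes = [n for n in model.graph.node if n.op_type == "Sign"]
    assert len(sign_nodes) > 0
    act_node = sign_nodes[0]
    matmul_node = find_consumer(model.graph, act_node.output[0])
    assert matmul_node is not None and matmul_node.op_type == "MatMul"

Following the Sign node's output to its consumer also makes the updated wiring explicit: after the Brevitas update the activation feeds MatMul input 0 and the weight initializer (now stored transposed, hence the transpose(1, 0) on int_weight) feeds input 1.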