Skip to content
Snippets Groups Projects
Commit 031edad6 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Enable QONNX ingestion for test_brevitas_act_export_relu_imagenet test.

parent 7e04af91
No related branches found
No related tags found
No related merge requests found
......@@ -114,7 +114,10 @@ scaling_impl.learned_value": torch.tensor(
@pytest.mark.parametrize("abits", [2, 4, 8])
@pytest.mark.parametrize("max_val", [1.0, 1.5, 1 - 2 ** (-7)])
@pytest.mark.parametrize("scaling_per_channel", [True, False])
def test_brevitas_act_export_relu_imagenet(abits, max_val, scaling_per_channel):
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_brevitas_act_export_relu_imagenet(
abits, max_val, scaling_per_channel, QONNX_export
):
out_channels = 32
ishape = (1, out_channels, 1, 1)
min_val = -1.0
......@@ -126,7 +129,7 @@ def test_brevitas_act_export_relu_imagenet(abits, max_val, scaling_per_channel):
restrict_scaling_type=RestrictValueType.LOG_FP,
scaling_min_val=2e-16,
max_val=6.0,
return_quant_tensor=True,
return_quant_tensor=False,
per_channel_broadcastable_shape=(1, out_channels, 1, 1),
)
if scaling_per_channel is True:
......@@ -140,7 +143,15 @@ scaling_impl.learned_value": rand_tensor.type(
)
}
b_act.load_state_dict(checkpoint)
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
if QONNX_export:
m_path = export_onnx_path
BrevitasONNXManager.export(b_act, ishape, m_path)
qonnx_cleanup(m_path, out_file=m_path)
model = ModelWrapper(m_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(m_path)
else:
bo.export_finn_onnx(b_act, ishape, export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
......@@ -151,7 +162,7 @@ scaling_impl.learned_value": rand_tensor.type(
produced = odict[model.graph.output[0].name]
inp_tensor = torch.from_numpy(inp_tensor).float()
b_act.eval()
expected = b_act.forward(inp_tensor).tensor.detach().numpy()
expected = b_act.forward(inp_tensor).detach().numpy()
if not np.isclose(produced, expected, atol=1e-3).all():
print(abits, max_val)
print("scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment