Skip to content
Snippets Groups Projects
Commit dbee47b9 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Added support for catching UserWarnings emitted by FINN during onnx execution as errors.

parent c59cc447
No related branches found
No related tags found
No related merge requests found
......@@ -35,6 +35,7 @@ import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import warnings
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from pkgutil import get_data
from qonnx.util.cleanup import cleanup
......@@ -140,13 +141,22 @@ def test_QONNX_to_FINN(model_name, wbits, abits):
# Compare output
model = ModelWrapper(qonnx_base_path.format("whole_trafo"))
input_dict = {model.graph.input[0].name: input_tensor}
output_dict = oxe.execute_onnx(model, input_dict, False)
test_output = output_dict[model.graph.output[0].name]
assert np.isclose(test_output, finn_export_output).all(), (
"The output of the FINN model "
"and the QONNX -> FINN converted model should match."
)
with warnings.catch_warnings(record=True) as warn_list:
warnings.simplefilter("always")
output_dict = oxe.execute_onnx(model, input_dict, False)
test_output = output_dict[model.graph.output[0].name]
assert np.isclose(test_output, finn_export_output).all(), (
"The output of the FINN model "
"and the QONNX -> FINN converted model should match."
)
# Check for UserWarnings
for warn in warn_list:
if issubclass(warn.category, UserWarning):
raise RuntimeError(
"Treating the following warning as an error, "
"since the warning is potentially unfixable for the user: "
+ str(warn)
)
# Run analysis passes on the converted model
model = ModelWrapper(qonnx_base_path.format("whole_trafo"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment