Skip to content
Snippets Groups Projects
Commit dff518d3 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] compare against expected output in conv lowering test

parent 656460d4
No related branches found
No related tags found
No related merge requests found
......@@ -27,9 +27,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import pkg_resources as pk
import brevitas.onnx as bo
import numpy as np
from finn.core.modelwrapper import ModelWrapper
......@@ -37,6 +37,7 @@ from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes
from finn.util.test import get_test_model_trained
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
import finn.core.onnx_exec as oxe
export_onnx_path = "test_output_cnv.onnx"
......@@ -47,6 +48,16 @@ def test_conv_lowering_cnv_w1a1():
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
assert input_tensor.shape == (1, 3, 32, 32)
# execute imported model to get expected answer
input_dict = {"0": input_tensor}
output_dict_e = oxe.execute_onnx(model, input_dict)
expected = output_dict_e[list(output_dict_e.keys())[0]]
# execute transformed model and compare
model = model.transform(LowerConvsToMatMul())
model.save("test.onnx")
output_dict_p = oxe.execute_onnx(model, input_dict)
produced = output_dict_p[list(output_dict_p.keys())[0]]
assert np.isclose(produced, expected).all()
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment