Skip to content
Snippets Groups Projects
Commit 53f5879a authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] update conv bn2affine test with actual execution

parent 6a0470ef
No related branches found
No related tags found
No related merge requests found
......@@ -28,10 +28,12 @@
import os
from pkgutil import get_data
import pkg_resources as pk
import brevitas.onnx as bo
import onnx
import onnx.numpy_helper as nph
import numpy as np
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
......@@ -64,11 +66,17 @@ def test_batchnorm_to_affine_cnv_w1a1():
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model.save("old.onnx")
# TODO shape inference failing on transformed model below -- needs debug
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
assert input_tensor.shape == (1, 3, 32, 32)
input_dict = {"0": input_tensor}
output_dict = oxe.execute_onnx(model, input_dict)
expected = output_dict[list(output_dict.keys())[0]]
new_model = model.transform(BatchNormToAffine())
# check that there are no BN nodes left
# TODO replace this with execution test
op_types = list(map(lambda x: x.op_type, new_model.graph.node))
assert "BatchNormalization" not in op_types
# os.remove(export_onnx_path)
output_dict_p = oxe.execute_onnx(new_model, input_dict)
produced = output_dict_p[list(output_dict_p.keys())[0]]
assert np.isclose(expected, produced).all()
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment