Skip to content
Snippets Groups Projects
Commit bbbaac0a authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add cnv testcase to batchnorm2affine

parent ade2ad7e
No related branches found
No related tags found
No related merge requests found
......@@ -12,15 +12,10 @@ from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes
from finn.util.test import get_test_model_trained
export_onnx_path = "test_output_lfc.onnx"
transformed_onnx_path = "test_output_lfc_transformed.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)
export_onnx_path = "test_output_bn2affine.onnx"
def test_batchnorm_to_affine():
def test_batchnorm_to_affine_lfc_w1a1():
lfc = get_test_model_trained("LFC", 1, 1)
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
......@@ -33,3 +28,18 @@ def test_batchnorm_to_affine():
input_dict = {"0": nph.to_array(input_tensor)}
assert oxe.compare_execution(model, new_model, input_dict)
os.remove(export_onnx_path)
def test_batchnorm_to_affine_cnv_w1a1():
lfc = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
# TODO shape inference failing on transformed model below -- needs debug
new_model = model.transform(BatchNormToAffine())
# check that there are no BN nodes left
# TODO replace this with execution test
op_types = list(map(lambda x: x.op_type, new_model.graph.node))
assert "BatchNormalization" not in op_types
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment