diff --git a/tests/transformation/test_batchnorm_to_affine.py b/tests/transformation/test_batchnorm_to_affine.py index d3bf26bfa0185b9305b1cb331d72270d85884738..d23934ce2b24531e13f106abe2d3108406ac8cb4 100644 --- a/tests/transformation/test_batchnorm_to_affine.py +++ b/tests/transformation/test_batchnorm_to_affine.py @@ -12,15 +12,10 @@ from finn.transformation.fold_constants import FoldConstants from finn.transformation.infer_shapes import InferShapes from finn.util.test import get_test_model_trained -export_onnx_path = "test_output_lfc.onnx" -transformed_onnx_path = "test_output_lfc_transformed.onnx" -# TODO get from config instead, hardcoded to Docker path for now -trained_lfc_checkpoint = ( - "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" -) +export_onnx_path = "test_output_bn2affine.onnx" -def test_batchnorm_to_affine(): +def test_batchnorm_to_affine_lfc_w1a1(): lfc = get_test_model_trained("LFC", 1, 1) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) @@ -33,3 +28,18 @@ def test_batchnorm_to_affine(): input_dict = {"0": nph.to_array(input_tensor)} assert oxe.compare_execution(model, new_model, input_dict) os.remove(export_onnx_path) + + +def test_batchnorm_to_affine_cnv_w1a1(): + lfc = get_test_model_trained("CNV", 1, 1) + bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + # TODO shape inference failing on transformed model below -- needs debug + new_model = model.transform(BatchNormToAffine()) + # check that there are no BN nodes left + # TODO replace this with execution test + op_types = list(map(lambda x: x.op_type, new_model.graph.node)) + assert "BatchNormalization" not in op_types + os.remove(export_onnx_path)