diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index 77657cf5e2ef14e38aa817e895488fd6dd310cde..401c5916415cd327a52a43f89c076bd7abd40647 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -67,6 +67,16 @@ class BatchNormToAffine(Transformation): if consumer.op_type == "Squeeze": bn_output = consumer.output[0] data_shape = model.get_tensor_shape(bn_input) + assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine" + assert ( + len(data_shape) >= A.ndim + ), "Unexpected number of dims found in BatchNormToAffine" + # reshape the mul/add constants to match the data shape/dims + # by adding (1,) dimensions to the right + n_spatial_dims = len(data_shape) - 2 + target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims)) + A = A.reshape(target_shape) + B = B.reshape(target_shape) # create value_info and initializers for Mul and Add constants mul_const = oh.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape diff --git a/tests/transformation/test_batchnorm_to_affine.py b/tests/transformation/test_batchnorm_to_affine.py index 8728707589ade72fb1b21ca0333c4d0757ac7df0..1ee817ea167864830ee2a0a02fe6806aa5977a26 100644 --- a/tests/transformation/test_batchnorm_to_affine.py +++ b/tests/transformation/test_batchnorm_to_affine.py @@ -28,10 +28,12 @@ import os from pkgutil import get_data +import pkg_resources as pk import brevitas.onnx as bo import onnx import onnx.numpy_helper as nph +import numpy as np import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper @@ -58,18 +60,23 @@ def test_batchnorm_to_affine_lfc_w1a1(): os.remove(export_onnx_path) -# cnv batchnorm to affine not yet supported - -# def test_batchnorm_to_affine_cnv_w1a1(): -# lfc = get_test_model_trained("CNV", 1, 1) -# bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path) -# model = ModelWrapper(export_onnx_path) -# model = model.transform(InferShapes()) -# model = model.transform(FoldConstants()) -# # TODO shape inference failing on transformed model below -- needs debug -# new_model = model.transform(BatchNormToAffine()) -# # check that there are no BN nodes left -# # TODO replace this with execution test -# op_types = list(map(lambda x: x.op_type, new_model.graph.node)) -# assert "BatchNormalization" not in op_types -# os.remove(export_onnx_path) +def test_batchnorm_to_affine_cnv_w1a1(): + lfc = get_test_model_trained("CNV", 1, 1) + bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") + input_tensor = np.load(fn)["arr_0"].astype(np.float32) + assert input_tensor.shape == (1, 3, 32, 32) + input_dict = {"0": input_tensor} + output_dict = oxe.execute_onnx(model, input_dict) + expected = output_dict[list(output_dict.keys())[0]] + new_model = model.transform(BatchNormToAffine()) + # check that there are no BN nodes left + op_types = list(map(lambda x: x.op_type, new_model.graph.node)) + assert "BatchNormalization" not in op_types + output_dict_p = oxe.execute_onnx(new_model, input_dict) + produced = output_dict_p[list(output_dict_p.keys())[0]] + assert np.isclose(expected, produced).all() + os.remove(export_onnx_path)