diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index 77657cf5e2ef14e38aa817e895488fd6dd310cde..401c5916415cd327a52a43f89c076bd7abd40647 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -67,6 +67,16 @@ class BatchNormToAffine(Transformation): if consumer.op_type == "Squeeze": bn_output = consumer.output[0] data_shape = model.get_tensor_shape(bn_input) + assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine" + assert ( + len(data_shape) >= A.ndim + ), "Unexpected number of dims found in BatchNormToAffine" + # reshape the mul/add constants to match the data shape/dims + # by adding (1,) dimensions to the right + n_spatial_dims = len(data_shape) - 2 + target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims)) + A = A.reshape(target_shape) + B = B.reshape(target_shape) # create value_info and initializers for Mul and Add constants mul_const = oh.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape