From 6a0470ef7e1358814df40f1686a886189ced5937 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 6 Mar 2020 16:25:34 +0000
Subject: [PATCH] [BN2Affine] handle BN param shapes better

---
 src/finn/transformation/batchnorm_to_affine.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 77657cf5e..401c59164 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -67,6 +67,16 @@ class BatchNormToAffine(Transformation):
                     if consumer.op_type == "Squeeze":
                         bn_output = consumer.output[0]
                 data_shape = model.get_tensor_shape(bn_input)
+                assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine"
+                assert (
+                    len(data_shape) >= A.ndim
+                ), "Unexpected number of dims found in BatchNormToAffine"
+                # reshape the mul/add constants to match the data shape/dims
+                # by adding (1,) dimensions to the right
+                n_spatial_dims = len(data_shape) - 2
+                target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims))
+                A = A.reshape(target_shape)
+                B = B.reshape(target_shape)
                 # create value_info and initializers for Mul and Add constants
                 mul_const = oh.make_tensor_value_info(
                     model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
-- 
GitLab