Skip to content
Snippets Groups Projects
Commit 6a0470ef authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[BN2Affine] handle BN param shapes better

parent 0136b4d2
No related branches found
No related tags found
No related merge requests found
......@@ -67,6 +67,16 @@ class BatchNormToAffine(Transformation):
if consumer.op_type == "Squeeze":
bn_output = consumer.output[0]
data_shape = model.get_tensor_shape(bn_input)
assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine"
assert (
len(data_shape) >= A.ndim
), "Unexpected number of dims found in BatchNormToAffine"
# reshape the mul/add constants to match the data shape/dims
# by adding (1,) dimensions to the right
n_spatial_dims = len(data_shape) - 2
target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims))
A = A.reshape(target_shape)
B = B.reshape(target_shape)
# create value_info and initializers for Mul and Add constants
mul_const = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment