From e96ca0f2a821d97214f287a95430e7bbcd4f7dde Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Wed, 1 Jul 2020 14:51:44 +0100 Subject: [PATCH] [Streamline] Extend MoveAddPastConv to move channelwise add nodes --- src/finn/transformation/streamline/reorder.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index a12c31c96..c3451e000 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -217,7 +217,7 @@ class MoveScalarAddPastMatMul(Transformation): class MoveAddPastConv(Transformation): - """Move scalar add operations past conv operations. We want to have adds + """Move scalar and channelwise add operations past conv operations. We want to have adds next to each other such that they can be collapsed into a single add.""" def apply(self, model): @@ -242,6 +242,8 @@ class MoveAddPastConv(Transformation): add_weight_name = n.input[1] conv_in_name = consumer.input[0] conv_in_shape = model.get_tensor_shape(conv_in_name) + # assume datalayout to be NCHW + channels = conv_in_shape[1] A = model.get_initializer(add_weight_name) assert A is not None, "Initializer for add weights is not set." start_name = n.input[0] @@ -252,11 +254,17 @@ class MoveAddPastConv(Transformation): pads = list(get_by_name(consumer.attribute, "pads").ints) if sum(pads) == 0: using_padding = False - if all(x == 1 for x in A.shape) and not using_padding: + if ( + all(x == 1 for x in A.shape) or A.shape == (1, channels, 1, 1) + ) and not using_padding: # create a tensor filled with the add constant, in # the shape expected by the convolution conv_in_const = np.zeros(conv_in_shape, dtype=np.float32) - conv_in_const.fill(A.item()) + if A.shape == (1, channels, 1, 1): + for ch in range(channels): + conv_in_const[0][ch].fill(A[0][ch].item()) + else: + conv_in_const.fill(A.item()) # create an execution context and put in const input exec_ctx = model.make_empty_exec_context() exec_ctx[conv_in_name] = conv_in_const -- GitLab