Skip to content
Snippets Groups Projects
Commit e96ca0f2 authored by auphelia's avatar auphelia
Browse files

[Streamline] Extend MoveAddPastConv to move channelwise add nodes

parent a8bbc11f
No related branches found
No related tags found
No related merge requests found
......@@ -217,7 +217,7 @@ class MoveScalarAddPastMatMul(Transformation):
class MoveAddPastConv(Transformation):
"""Move scalar add operations past conv operations. We want to have adds
"""Move scalar and channelwise add operations past conv operations. We want to have adds
next to each other such that they can be collapsed into a single add."""
def apply(self, model):
......@@ -242,6 +242,8 @@ class MoveAddPastConv(Transformation):
add_weight_name = n.input[1]
conv_in_name = consumer.input[0]
conv_in_shape = model.get_tensor_shape(conv_in_name)
# assume datalayout to be NCHW
channels = conv_in_shape[1]
A = model.get_initializer(add_weight_name)
assert A is not None, "Initializer for add weights is not set."
start_name = n.input[0]
......@@ -252,11 +254,17 @@ class MoveAddPastConv(Transformation):
pads = list(get_by_name(consumer.attribute, "pads").ints)
if sum(pads) == 0:
using_padding = False
if all(x == 1 for x in A.shape) and not using_padding:
if (
all(x == 1 for x in A.shape) or A.shape == (1, channels, 1, 1)
) and not using_padding:
# create a tensor filled with the add constant, in
# the shape expected by the convolution
conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
conv_in_const.fill(A.item())
if A.shape == (1, channels, 1, 1):
for ch in range(channels):
conv_in_const[0][ch].fill(A[0][ch].item())
else:
conv_in_const.fill(A.item())
# create an execution context and put in const input
exec_ctx = model.make_empty_exec_context()
exec_ctx[conv_in_name] = conv_in_const
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment