Commit 573c470b authored by Tobi-Alonso

[BUGFIX] Moving Add past Conv with padding is not currently supported; add a check for that.

parent b83183bd
@@ -244,7 +244,13 @@ class MoveScalarAddPastConv(Transformation):
                     start_name = n.input[0]
                     end_name = consumer.output[0]
                     conv_out_shape = model.get_tensor_shape(end_name)
-                    if all(x == 1 for x in A.shape):
+                    using_padding = True
+                    for att_idx, attr in enumerate(consumer.attribute):
+                        if attr.name == "pads":
+                            if sum(attr.ints) == 0:
+                                using_padding = False
+                    if all(x == 1 for x in A.shape) and not using_padding:
                         # create a tensor filled with the add constant, in
                         # the shape expected by the convolution
                         conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
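The check added above can be tried in isolation. The following is a minimal sketch (not FINN code; only the onnx package is assumed, and the node below is a made-up example) that builds a standalone Conv node and applies the same attribute scan:

import onnx.helper as helper

# hypothetical Conv node with symmetric padding of 1 on all sides
conv = helper.make_node(
    "Conv", ["x", "W"], ["y"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]
)

# same logic as the hunk above: assume padding unless an explicit
# all-zero "pads" attribute is present
using_padding = True
for attr in conv.attribute:
    if attr.name == "pads":
        if sum(attr.ints) == 0:
            using_padding = False

print(using_padding)  # True here, so the transformation would now skip this Conv

A Conv with no "pads" attribute is treated as padded, matching the default of True in the diff; only an explicit all-zero "pads" disables the guard.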
@@ -256,7 +262,8 @@ class MoveScalarAddPastConv(Transformation):
                         execute_node(conv_node, exec_ctx, model.graph)
                         # retrieve the conv output
                         Anew = exec_ctx[end_name]
-                        # strip out repetition
+                        # strip out repetition if no padding
                         Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)
                         # update the add weight
                         model.set_initializer(add_weight_name, Anew)
@@ -274,6 +281,7 @@ class MoveScalarAddPastConv(Transformation):
                         graph.node.remove(add_node)
                         graph.node.insert(node_ind, add_node)
                         graph_modified = True
+        model = model.transform(InferShapes())
         return (model, graph_modified)
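Why the guard is needed: the transformation derives the new Add constant by running the Conv on an input filled with the original scalar and reading one output pixel per channel (Anew[0, :, 0, 0]). With zero padding, every output pixel of that constant input is identical, so reordering Add and Conv is exact; with nonzero padding the border pixels see padded zeros instead of the constant, and the rewrite would change the network's output. A small single-channel numpy sketch (illustrative only; names and shapes are made up, scipy is assumed available) shows both cases:

import numpy as np
from scipy.signal import correlate2d


def conv2d(x, w, pad):
    # single-channel Conv (cross-correlation) with symmetric zero padding
    xp = np.pad(x, pad, mode="constant")
    return correlate2d(xp, w, mode="valid")


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6)).astype(np.float32)
w = rng.standard_normal((3, 3)).astype(np.float32)
c = 0.5  # the scalar Add constant

for pad in (0, 1):
    # original graph order: Add, then Conv
    before = conv2d(x + c, w, pad)
    # reordered graph: Conv, then Add with the derived constant, obtained the
    # same way the transformation does it: run the Conv on a constant-filled
    # input and take a single output pixel
    c_new = conv2d(np.full_like(x, c), w, pad)[0, 0]
    after = conv2d(x, w, pad) + c_new
    print(f"pad={pad}: outputs match -> {bool(np.allclose(before, after))}")

Expected output: pad=0 reports True and pad=1 reports False, which is exactly the case the new using_padding check excludes.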