Unverified commit 4dd4ed3b authored by auphelia, committed by GitHub

Merge pull request #114 from quetric/feature/add_past_conv_check_padding

Feature/add past conv check padding
parents 316ca19d b2f8bc71
@@ -244,7 +244,12 @@ class MoveScalarAddPastConv(Transformation):
     start_name = n.input[0]
     end_name = consumer.output[0]
     conv_out_shape = model.get_tensor_shape(end_name)
-    if all(x == 1 for x in A.shape):
+    using_padding = True
+    pads = list(get_by_name(consumer.attribute, "pads").ints)
+    if sum(pads) == 0:
+        using_padding = False
+    if all(x == 1 for x in A.shape) and not using_padding:
         # create a tensor filled with the add constant, in
         # the shape expected by the convolution
         conv_in_const = np.zeros(conv_in_shape, dtype=np.float32)
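
Why the new guard is needed: a scalar add only commutes with an unpadded convolution; with zero padding, the padded border values are not offset by the add constant, so the equality breaks at the edges. A minimal 1-D numpy sketch (illustrative only, not part of this commit):

import numpy as np

np.random.seed(0)
x = np.random.rand(8).astype(np.float32)  # input signal
k = np.random.rand(3).astype(np.float32)  # conv kernel
c = np.float32(0.5)                       # scalar add constant

# No padding ("valid"): every output position sees c at all kernel taps,
# so conv(x + c) == conv(x) + c * k.sum() and the Add can move past the Conv.
assert np.allclose(
    np.convolve(x + c, k, mode="valid"),
    np.convolve(x, k, mode="valid") + c * k.sum(),
)

# Zero padding ("same"): the padded zeros are not offset by c, so the border
# outputs differ and no single moved scalar can compensate.
assert not np.allclose(
    np.convolve(x + c, k, mode="same"),
    np.convolve(x, k, mode="same") + c * k.sum(),
)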
@@ -256,7 +261,8 @@ class MoveScalarAddPastConv(Transformation):
     execute_node(conv_node, exec_ctx, model.graph)
     # retrieve the conv output
     Anew = exec_ctx[end_name]
-    # strip out repetition
+    # strip out repetition if no padding
     Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)
     # update the add weight
     model.set_initializer(add_weight_name, Anew)
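
What the reshape exploits: with no padding, executing the Conv on a constant-filled input yields one constant per output channel, so the full conv output collapses to a (1, C, 1, 1) add weight. A small illustration with hypothetical values (not from the commit):

import numpy as np

# conv output in which each of the 3 channels is constant over the 5x5 map
Anew = np.tile(np.arange(3, dtype=np.float32).reshape(1, 3, 1, 1), (1, 1, 5, 5))
compact = Anew[0, :, 0, 0].reshape(1, -1, 1, 1)  # keep one value per channel
assert np.allclose(Anew, compact)  # broadcasting recovers the full tensor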
@@ -274,6 +280,7 @@ class MoveScalarAddPastConv(Transformation):
         graph.node.remove(add_node)
         graph.node.insert(node_ind, add_node)
         graph_modified = True
+    model = model.transform(InferShapes())
     return (model, graph_modified)
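
For context, the transformation is applied through the usual FINN transform pattern; a minimal usage sketch (the file name model.onnx is a placeholder, not from this commit):

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import MoveScalarAddPastConv

model = ModelWrapper("model.onnx")  # placeholder path
model = model.transform(InferShapes())
model = model.transform(MoveScalarAddPastConv())
model.save("model_streamlined.onnx")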
@@ -12,6 +12,85 @@ from finn.transformation.streamline import (
 )


+@pytest.mark.parametrize("padding", [False, True])
+@pytest.mark.parametrize(
+    "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
+)
+def test_move_scalar_past_conv(test_args, padding):
+    scalar_op = test_args[0]
+    transf_fxn = test_args[1]
+
+    in_feature_dim = 7
+    in_chn = 3
+    stages = 2
+    kernel_size = 3
+
+    out_feature_dim = (
+        in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
+    )
+
+    input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
+    output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
+
+    conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
+
+    conv_config = {}
+    conv_config["dilations"] = [1, 1]
+    conv_config["group"] = 1
+    conv_config["kernel_shape"] = [kernel_size, kernel_size]
+    if padding:
+        conv_config["pads"] = [1, 1, 1, 1]
+    else:
+        conv_config["pads"] = [0, 0, 0, 0]
+    conv_config["strides"] = [1, 1]
+
+    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
+
+    value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
+    value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
+    value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
+
+    modelproto = oh.make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
+                oh.make_node("Conv", ["t1", "p2"], ["t2"], **conv_config),
+                oh.make_node("Conv", ["t2", "p3"], ["top_out"], **conv_config),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
+    model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
+    model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
+
+    new_model = model.transform(transf_fxn)
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    assert ox.compare_execution(model, new_model, inp_dict)
+    if scalar_op == "Add":
+        if padding:
+            assert new_model.graph.node[0].op_type == scalar_op
+            assert new_model.graph.node[1].op_type == "Conv"
+            assert new_model.graph.node[2].op_type == "Conv"
+        else:
+            assert new_model.graph.node[0].op_type == "Conv"
+            assert new_model.graph.node[1].op_type == scalar_op
+            assert new_model.graph.node[2].op_type == "Conv"
+    else:
+        assert new_model.graph.node[0].op_type == "Conv"
+        assert new_model.graph.node[1].op_type == "Conv"
+        assert new_model.graph.node[2].op_type == scalar_op
+
+
 @pytest.mark.parametrize(
     "test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
 )
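
As a quick check on the test's size arithmetic (not part of the commit): each unpadded 3x3 conv trims kernel_size - 1 = 2 pixels per spatial dimension, so two stages shrink the 7x7 input to 3x3, while the padded variant preserves 7x7:

# out_feature_dim as computed in the test above, for padding=False
in_feature_dim, kernel_size, stages = 7, 3, 2
shrink_per_conv = kernel_size // 2 * 2  # 2 pixels lost per unpadded 3x3 conv
assert in_feature_dim - shrink_per_conv * stages == 3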