Skip to content
Snippets Groups Projects
Commit 8f2561ae authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[TEST] Add test for scalar past conv

parent 573c470b
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,85 @@ from finn.transformation.streamline import (
)
@pytest.mark.parametrize("padding", [False, True])
@pytest.mark.parametrize(
"test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
)
def test_move_scalar_past_conv(test_args, padding):
scalar_op = test_args[0]
transf_fxn = test_args[1]
in_feature_dim = 7
in_chn = 3
stages = 2
kernel_size = 3
out_feature_dim = (
in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
)
input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
conv_config = {}
conv_config["dilations"] = [1, 1]
conv_config["group"] = 1
conv_config["kernel_shape"] = [kernel_size, kernel_size]
if padding:
conv_config["pads"] = [1, 1, 1, 1]
else:
conv_config["pads"] = [0, 0, 0, 0]
conv_config["strides"] = [1, 1]
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
modelproto = oh.make_model(
oh.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=[
oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
oh.make_node("Conv", ["t1", "p2"], ["t2"], **conv_config),
oh.make_node("Conv", ["t2", "p3"], ["top_out"], **conv_config),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
np.random.seed(0)
model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
new_model = model.transform(transf_fxn)
inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
if scalar_op == "Add":
if padding:
assert new_model.graph.node[0].op_type == scalar_op
assert new_model.graph.node[1].op_type == "Conv"
assert new_model.graph.node[2].op_type == "Conv"
else:
assert new_model.graph.node[0].op_type == "Conv"
assert new_model.graph.node[1].op_type == scalar_op
assert new_model.graph.node[2].op_type == "Conv"
else:
assert new_model.graph.node[0].op_type == "Conv"
assert new_model.graph.node[1].op_type == "Conv"
assert new_model.graph.node[2].op_type == scalar_op
@pytest.mark.parametrize(
"test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment