Skip to content
Snippets Groups Projects
Commit 7938f93c authored by auphelia's avatar auphelia
Browse files

[Test] Add test with data layout is None for MoveTransposePastScalarMul

parent f724b641
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ import finn.core.onnx_exec as oxe
# scalar mul
@pytest.mark.parametrize("scalar", [True, False])
# data layout
@pytest.mark.parametrize("data_layout", [DataLayout.NHWC, DataLayout.NCHW])
@pytest.mark.parametrize("data_layout", [None, DataLayout.NHWC, DataLayout.NCHW])
def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 3, 4])
# to determine out_size we need to calculate with "perm" for this test case
......@@ -47,11 +47,12 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
# initialize values
a0_values = np.random.uniform(low=0, high=1, size=tuple(a0_size)).astype(np.float32)
model.set_initializer("a0", a0_values)
model.set_tensor_layout("inp", data_layout)
if data_layout is not None:
model.set_tensor_layout("inp", data_layout)
model = model.transform(InferDataLayouts())
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
......@@ -62,7 +63,7 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
assert oxe.compare_execution(model, model_transformed, idict)
# check if order changed
if scalar is True:
if scalar is True and data_layout is not None:
assert model_transformed.graph.node[0] != model.graph.node[0]
assert model_transformed.graph.node[1] != model.graph.node[1]
assert model_transformed.graph.node[0].op_type == "Mul"
......@@ -74,7 +75,8 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
else:
assert model_transformed.graph.node[0] == model.graph.node[0]
assert model_transformed.graph.node[1] == model.graph.node[1]
mul_input = model_transformed.graph.node[1].input[0]
mul_output = model_transformed.graph.node[1].output[0]
assert model_transformed.get_tensor_layout(mul_input) != data_layout
assert model_transformed.get_tensor_layout(mul_output) != data_layout
if data_layout is not None:
mul_input = model_transformed.graph.node[1].input[0]
mul_output = model_transformed.graph.node[1].output[0]
assert model_transformed.get_tensor_layout(mul_input) != data_layout
assert model_transformed.get_tensor_layout(mul_output) != data_layout
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment