From 7938f93c038ea93777067935e86dea5da6484f04 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 26 Jun 2020 15:30:01 +0100
Subject: [PATCH] [Test] Add test with data layout is None for
 MoveTransposePastScalarMul

---
 .../test_move_transpose_past_scalar_mul.py     | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/tests/transformation/test_move_transpose_past_scalar_mul.py b/tests/transformation/test_move_transpose_past_scalar_mul.py
index 7e48c61a2..e434fc7d4 100644
--- a/tests/transformation/test_move_transpose_past_scalar_mul.py
+++ b/tests/transformation/test_move_transpose_past_scalar_mul.py
@@ -17,7 +17,7 @@ import finn.core.onnx_exec as oxe
 # scalar mul
 @pytest.mark.parametrize("scalar", [True, False])
 # data layout
-@pytest.mark.parametrize("data_layout", [DataLayout.NHWC, DataLayout.NCHW])
+@pytest.mark.parametrize("data_layout", [None, DataLayout.NHWC, DataLayout.NCHW])
 def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 3, 4])
     # to determine out_size we need to calculate with "perm" for this test case
@@ -47,11 +47,12 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
     # initialize values
     a0_values = np.random.uniform(low=0, high=1, size=tuple(a0_size)).astype(np.float32)
     model.set_initializer("a0", a0_values)
-    model.set_tensor_layout("inp", data_layout)
+    if data_layout is not None:
+        model.set_tensor_layout("inp", data_layout)
+        model = model.transform(InferDataLayouts())
 
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
-    model = model.transform(InferDataLayouts())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
 
@@ -62,7 +63,7 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
     assert oxe.compare_execution(model, model_transformed, idict)
 
     # check if order changed
-    if scalar is True:
+    if scalar is True and data_layout is not None:
         assert model_transformed.graph.node[0] != model.graph.node[0]
         assert model_transformed.graph.node[1] != model.graph.node[1]
         assert model_transformed.graph.node[0].op_type == "Mul"
@@ -74,7 +75,8 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout):
     else:
         assert model_transformed.graph.node[0] == model.graph.node[0]
         assert model_transformed.graph.node[1] == model.graph.node[1]
-        mul_input = model_transformed.graph.node[1].input[0]
-        mul_output = model_transformed.graph.node[1].output[0]
-        assert model_transformed.get_tensor_layout(mul_input) != data_layout
-        assert model_transformed.get_tensor_layout(mul_output) != data_layout
+        if data_layout is not None:
+            mul_input = model_transformed.graph.node[1].input[0]
+            mul_output = model_transformed.graph.node[1].output[0]
+            assert model_transformed.get_tensor_layout(mul_input) != data_layout
+            assert model_transformed.get_tensor_layout(mul_output) != data_layout
-- 
GitLab