diff --git a/tests/transformation/test_move_transpose_past_scalar_mul.py b/tests/transformation/test_move_transpose_past_scalar_mul.py index 7e48c61a297e11a668bb74df990bdb99594bd0de..e434fc7d4f683120176e18a2bfa9da99d9ee0b0e 100644 --- a/tests/transformation/test_move_transpose_past_scalar_mul.py +++ b/tests/transformation/test_move_transpose_past_scalar_mul.py @@ -17,7 +17,7 @@ import finn.core.onnx_exec as oxe # scalar mul @pytest.mark.parametrize("scalar", [True, False]) # data layout -@pytest.mark.parametrize("data_layout", [DataLayout.NHWC, DataLayout.NCHW]) +@pytest.mark.parametrize("data_layout", [None, DataLayout.NHWC, DataLayout.NCHW]) def test_move_transpose_past_scalar_mul(perm, scalar, data_layout): inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 3, 4]) # to determine out_size we need to calculate with "perm" for this test case @@ -47,11 +47,12 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout): # initialize values a0_values = np.random.uniform(low=0, high=1, size=tuple(a0_size)).astype(np.float32) model.set_initializer("a0", a0_values) - model.set_tensor_layout("inp", data_layout) + if data_layout is not None: + model.set_tensor_layout("inp", data_layout) + model = model.transform(InferDataLayouts()) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) - model = model.transform(InferDataLayouts()) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveReadableTensorNames()) @@ -62,7 +63,7 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout): assert oxe.compare_execution(model, model_transformed, idict) # check if order changed - if scalar is True: + if scalar is True and data_layout is not None: assert model_transformed.graph.node[0] != model.graph.node[0] assert model_transformed.graph.node[1] != model.graph.node[1] assert model_transformed.graph.node[0].op_type == "Mul" @@ -74,7 +75,8 @@ def test_move_transpose_past_scalar_mul(perm, scalar, data_layout): else: assert model_transformed.graph.node[0] == model.graph.node[0] assert model_transformed.graph.node[1] == model.graph.node[1] - mul_input = model_transformed.graph.node[1].input[0] - mul_output = model_transformed.graph.node[1].output[0] - assert model_transformed.get_tensor_layout(mul_input) != data_layout - assert model_transformed.get_tensor_layout(mul_output) != data_layout + if data_layout is not None: + mul_input = model_transformed.graph.node[1].input[0] + mul_output = model_transformed.graph.node[1].output[0] + assert model_transformed.get_tensor_layout(mul_input) != data_layout + assert model_transformed.get_tensor_layout(mul_output) != data_layout