diff --git a/tests/transformation/test_move_transpose_past_scalar_mul.py b/tests/transformation/test_move_transpose_past_scalar_mul.py index 2fef09a78cc2f1b532c767a6dfd360b34779ff34..7e48c61a297e11a668bb74df990bdb99594bd0de 100644 --- a/tests/transformation/test_move_transpose_past_scalar_mul.py +++ b/tests/transformation/test_move_transpose_past_scalar_mul.py @@ -4,8 +4,10 @@ import numpy as np from onnx import TensorProto, helper from finn.core.modelwrapper import ModelWrapper +import finn.core.data_layout as DataLayout from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.infer_data_layouts import InferDataLayouts from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames from finn.transformation.streamline.reorder import MoveTransposePastScalarMul import finn.core.onnx_exec as oxe @@ -14,7 +16,9 @@ import finn.core.onnx_exec as oxe @pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]]) # scalar mul @pytest.mark.parametrize("scalar", [True, False]) -def test_move_transpose_past_scalar_mul(perm, scalar): +# data layout +@pytest.mark.parametrize("data_layout", [DataLayout.NHWC, DataLayout.NCHW]) +def test_move_transpose_past_scalar_mul(perm, scalar, data_layout): inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 3, 4]) # to determine out_size we need to calculate with "perm" for this test case dummy_in = np.random.uniform(low=0, high=1, size=(1, 2, 3, 4)).astype(np.float32) @@ -43,9 +47,11 @@ def test_move_transpose_past_scalar_mul(perm, scalar): # initialize values a0_values = np.random.uniform(low=0, high=1, size=tuple(a0_size)).astype(np.float32) model.set_initializer("a0", a0_values) + model.set_tensor_layout("inp", data_layout) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) + model = model.transform(InferDataLayouts()) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveReadableTensorNames()) @@ -61,6 +67,14 @@ def test_move_transpose_past_scalar_mul(perm, scalar): assert model_transformed.graph.node[1] != model.graph.node[1] assert model_transformed.graph.node[0].op_type == "Mul" assert model_transformed.graph.node[1].op_type == "Transpose" + mul_input = model_transformed.graph.node[0].input[0] + mul_output = model_transformed.graph.node[0].output[0] + assert model_transformed.get_tensor_layout(mul_input) == data_layout + assert model_transformed.get_tensor_layout(mul_output) == data_layout else: assert model_transformed.graph.node[0] == model.graph.node[0] assert model_transformed.graph.node[1] == model.graph.node[1] + mul_input = model_transformed.graph.node[1].input[0] + mul_output = model_transformed.graph.node[1].output[0] + assert model_transformed.get_tensor_layout(mul_input) != data_layout + assert model_transformed.get_tensor_layout(mul_output) != data_layout