diff --git a/tests/transformation/test_move_flatten_past_affine.py b/tests/transformation/test_move_flatten_past_affine.py index ba0a7f888658663e4aab530336316ddd72f8baee..e5588965ce1c2d006529b0f8895c489e7a38b6c7 100644 --- a/tests/transformation/test_move_flatten_past_affine.py +++ b/tests/transformation/test_move_flatten_past_affine.py @@ -25,27 +25,40 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest import numpy as np from onnx import TensorProto, helper from finn.core.modelwrapper import ModelWrapper from finn.core.datatype import DataType +import finn.core.data_layout as DataLayout from finn.util.basic import gen_finn_dt_tensor from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.infer_data_layouts import InferDataLayouts from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames from finn.transformation.streamline.reorder import MoveFlattenPastAffine import finn.core.onnx_exec as oxe +# data layout +@pytest.mark.parametrize("data_layout", [DataLayout.NHWC, DataLayout.NCHW]) +# batch size +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_move_flatten_past_affine(data_layout, batch_size): + if data_layout == DataLayout.NHWC: + ishape = [batch_size, 1, 1, 1024] + oshape = [batch_size, 1000] + else: + ishape = [batch_size, 1024, 1, 1] + oshape = [batch_size, 1000] -def test_move_flatten_past_affine(): - inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1024]) + inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape) a0 = helper.make_tensor_value_info("a1", TensorProto.FLOAT, [1024, 1000]) a1 = helper.make_tensor_value_info("a2", TensorProto.FLOAT, []) a2 = helper.make_tensor_value_info("a3", TensorProto.FLOAT, [1000]) + outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape) - outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1000]) flatten_node = helper.make_node("Flatten", ["inp"], ["flatten_out"]) matmul_node = helper.make_node("MatMul", ["flatten_out", "a0"], ["matmul_out"]) mul_node = helper.make_node("Mul", ["matmul_out", "a1"], ["mul_out"]) @@ -71,18 +84,23 @@ def test_move_flatten_past_affine(): model.set_initializer("a2", a2_values) model.set_tensor_datatype("inp", DataType.INT2) - model.set_tensor_datatype("flatten_out", DataType.INT2) + model.set_tensor_layout("inp", data_layout) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) + model = model.transform(InferDataLayouts()) model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveReadableTensorNames()) # compare execution before and after transformation - inp_values = gen_finn_dt_tensor(DataType.INT2, [1, 1, 1, 1024]) + inp_values = gen_finn_dt_tensor(DataType.INT2, ishape) idict = {"inp": inp_values} model_transformed = model.transform(MoveFlattenPastAffine()) assert oxe.compare_execution(model, model_transformed, idict) - # check if nodes have new order in transformed graph - assert model.graph != model_transformed.graph - assert model_transformed.graph.node[-1].op_type == "Flatten" + # depending on data layout check if graph is transformed or not + if data_layout == DataLayout.NHWC: + # check if nodes have new order in transformed graph + assert model.graph != model_transformed.graph + assert model_transformed.graph.node[-1].op_type == "Flatten" + else: + assert model.graph == model_transformed.graph