import pytest

import numpy as np
from onnx import TensorProto, helper

from finn.core.modelwrapper import ModelWrapper
import finn.core.data_layout as DataLayout
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten
import finn.core.onnx_exec as oxe

# permutation of transpose node
@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
# target shape of a Reshape node, or None to use a Flatten node instead
@pytest.mark.parametrize("shape", [None, [1, -1], [-1, 1]])
# input shape
@pytest.mark.parametrize("ishape", [[1, 1, 1, 4], [2, 4, 1, 1], [1, 2, 2, 4]])
# data layout annotated on the input tensor
@pytest.mark.parametrize("data_layout", ["NCHW", "NHWC"])
def test_absorb_transp_into_flatten(perm, shape, ishape, data_layout):
    """Check AbsorbTransposeIntoFlatten on a Transpose -> Flatten/Reshape graph.

    The transformed model must always produce the same output as the original.
    For the parameter combinations where the transpose can be absorbed, the
    transformed graph must consist of a single Flatten node; for all other
    combinations the graph must be left unchanged.
    """
    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
    transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
    # only the shape of the transposed tensor is needed to derive the output shape
    transposed = np.zeros(tuple(ishape), dtype=np.float32).transpose(tuple(perm))
    if shape is None:
        # Flatten without an axis attribute keeps dim 0 and flattens the rest
        shape_node = helper.make_node("Flatten", ["transp_out"], ["outp"])
        oshape = transposed.reshape(transposed.shape[0], -1).shape
        shape0 = None
    else:
        # Reshape takes its target shape as a 1-D int64 tensor (ONNX spec),
        # so declare it as such instead of as a FLOAT tensor shaped like `shape`
        shape0 = helper.make_tensor_value_info(
            "shape0", TensorProto.INT64, [len(shape)]
        )
        shape_node = helper.make_node("Reshape", ["transp_out", "shape0"], ["outp"])
        oshape = transposed.reshape(tuple(shape)).shape
    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)

    graph = helper.make_graph(
        nodes=[transp_node, shape_node],
        name="absorb-transpose-graph",
        inputs=[inp],
        outputs=[outp],
    )

    model = helper.make_model(graph, producer_name="absorb_transpose_model")
    model = ModelWrapper(model)
    if shape is not None:
        model.graph.value_info.append(shape0)
        # explicit int64: np.asarray's default integer dtype is platform dependent
        model.set_initializer("shape0", np.asarray(shape, dtype=np.int64))
    if data_layout == "NCHW":
        model.set_tensor_layout("inp", DataLayout.NCHW)
    else:
        model.set_tensor_layout("inp", DataLayout.NHWC)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    model = model.transform(InferDataLayouts())
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(GiveReadableTensorNames())
    model_transformed = model.transform(AbsorbTransposeIntoFlatten())

    # verify the transformation preserves the model's output
    inp_values = np.random.uniform(low=-1, high=1, size=tuple(ishape)).astype(
        np.float32
    )
    idict = {model.graph.input[0].name: inp_values}
    assert oxe.compare_execution(model, model_transformed, idict)

    # only some of the parameter combinations lead to a graph that is changed
    # when AbsorbTransposeIntoFlatten is applied
    if shape == [-1, 1]:
        # not a flatten operation, so the graph will not be changed
        assert model.graph == model_transformed.graph

    elif perm == [3, 2, 0, 1]:
        # the first dimension is also part of the transpose operation,
        # so the graph will not be changed
        assert model.graph == model_transformed.graph

    # the following cases are the ones in which the model is transformed;
    # because shape and perm were already handled above, only ishape and
    # data_layout remain to be considered (the transformed model should only
    # contain a single "Flatten" node)
    elif ishape == [1, 1, 1, 4] and data_layout == "NHWC":
        assert len(model_transformed.graph.node) == 1
        assert model_transformed.graph.node[0].op_type == "Flatten"

    elif ishape == [2, 4, 1, 1] and data_layout == "NCHW" and shape is None:
        # If the first dimension of the input tensor is not 1, flatten and
        # reshape (with shape = [1, -1]) would lead to different results
        assert len(model_transformed.graph.node) == 1
        assert model_transformed.graph.node[0].op_type == "Flatten"

    # all other cases lead to an unchanged model
    else:
        assert model.graph == model_transformed.graph