import os
import shutil
from functools import reduce
from operator import mul

import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.numpy_helper as nph
import onnx.shape_inference as si
import torch
import wget
from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
from torch.nn import BatchNorm1d, Dropout, Module, ModuleList

import finn.core.onnx_exec as oxe

# Widths of the hidden quantized fully-connected layers built by LFC
# (one Linear/BatchNorm1d/act-quant/Dropout group per entry).
FC_OUT_FEATURES = [1024] * 3
# Values passed as ``per_out_ch_scaling`` to get_quant_linear for the hidden
# layers and for the final classifier layer, respectively.
INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
LAST_FC_PER_OUT_CH_SCALING = False
# Dropout probabilities: after the input quantizer, and after each hidden layer.
IN_DROPOUT = 0.2
HIDDEN_DROPOUT = 0.2

# ONNX model-zoo MNIST archive; the tests read a bundled test vector from it.
mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
mnist_onnx_filename = "mnist.tar.gz"
# Directory the downloaded archive is unpacked into.
mnist_onnx_local_dir = "/tmp/mnist_onnx"
# Output file (relative to the working directory) for the exported FINN-ONNX model.
export_onnx_path = "test_output_lfc.onnx"
# Pretrained LFC-w1a1 checkpoint.
# TODO: read this from a config instead; hardcoded to the Docker path for now.
trained_lfc_checkpoint = (
    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)


class LFC(Module):
    """Quantized fully-connected classifier (LFC topology) built from Brevitas helpers.

    ``self.features`` holds, in order: an input activation quantizer, input
    dropout, then for every width in ``FC_OUT_FEATURES`` a quantized linear
    layer, ``BatchNorm1d``, an activation quantizer and dropout. ``self.fc`` is
    the final quantized linear layer producing ``num_classes`` outputs.

    Args:
        num_classes: number of output logits.
        weight_bit_width: weight bit width, forwarded to the ``models.common``
            helpers (the meaning of ``None`` is defined there).
        act_bit_width: bit width of the hidden activation quantizers.
        in_bit_width: bit width of the input activation quantizer.
        in_ch: number of input channels.
        in_features: spatial input shape; together with ``in_ch`` it fixes the
            flattened input size seen by the first linear layer.
    """

    def __init__(
        self,
        num_classes=10,
        weight_bit_width=None,
        act_bit_width=None,
        in_bit_width=None,
        in_ch=1,
        in_features=(28, 28),
    ):
        super().__init__()

        weight_quant_type = get_quant_type(weight_bit_width)
        act_quant_type = get_quant_type(act_bit_width)
        in_quant_type = get_quant_type(in_bit_width)
        stats_op = get_stats_op(weight_quant_type)

        # Flattened size of one input sample. forward() reshapes to this, so
        # the model honours in_ch/in_features instead of assuming 1x28x28.
        # For the defaults this is 784, exactly the previous hard-coded value.
        self.in_features_flat = in_ch * reduce(mul, in_features)

        self.features = ModuleList()
        self.features.append(get_act_quant(in_bit_width, in_quant_type))
        self.features.append(Dropout(p=IN_DROPOUT))
        in_features = self.in_features_flat
        for out_features in FC_OUT_FEATURES:
            self.features.append(
                get_quant_linear(
                    in_features=in_features,
                    out_features=out_features,
                    per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
                    bit_width=weight_bit_width,
                    quant_type=weight_quant_type,
                    stats_op=stats_op,
                )
            )
            in_features = out_features
            self.features.append(BatchNorm1d(num_features=in_features))
            self.features.append(get_act_quant(act_bit_width, act_quant_type))
            self.features.append(Dropout(p=HIDDEN_DROPOUT))
        self.fc = get_quant_linear(
            in_features=in_features,
            out_features=num_classes,
            per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
            bit_width=weight_bit_width,
            quant_type=weight_quant_type,
            stats_op=stats_op,
        )

    def forward(self, x):
        """Run one sample through the network and return the ``self.fc`` output.

        ``x`` is reshaped to ``(1, in_ch * prod(in_features))``, so it must hold
        exactly one sample's worth of elements. It is then rescaled as
        ``2 * x - 1`` before the feature layers.
        """
        # The batch dimension is fixed at 1 (rather than x.shape[0]); presumably
        # so the exported ONNX graph gets a constant reshape target instead of
        # dynamic shape ops -- TODO confirm before generalizing.
        x = x.view(1, self.in_features_flat)
        # removing the torch.tensor here creates a float64 op for some reason..
        # so explicitly wrapped with torch.tensor to make a float32 one instead
        x = 2.0 * x - torch.tensor([1.0])
        for mod in self.features:
            x = mod(x)
        return self.fc(x)


def test_brevitas_to_onnx_export():
    """Export an untrained LFC-w1a1 to FINN-ONNX and check the graph layout.

    Checks the graph's input/node/output counts and that the first ``MatMul``
    is fed by a ``Sign`` node. Both MatMul operands must be annotated
    ``BIPOLAR``, and its weight initializer must equal the transposed integer
    weights of the first quantized linear layer (``lfc.features[2]``).
    """
    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
    try:
        bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
        model = onnx.load(export_onnx_path)
        # TODO the following way of testing is highly sensitive to small changes
        # in PyTorch ONNX export: the order, names, count... of nodes could
        # easily change between different versions, and break this test.
        assert len(model.graph.input) == 23
        assert len(model.graph.node) == 24
        assert len(model.graph.output) == 1
        assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
        act_node = model.graph.node[3]
        assert act_node.op_type == "Sign"
        matmul_node = model.graph.node[4]
        assert matmul_node.op_type == "MatMul"
        assert act_node.output[0] == matmul_node.input[0]
        inits = [x.name for x in model.graph.initializer]
        qnt_annotations = {
            a.tensor_name: a.quant_parameter_tensor_names[0].value
            for a in model.graph.quantization_annotation
        }
        assert qnt_annotations[matmul_node.input[0]] == "BIPOLAR"
        assert matmul_node.input[1] in inits
        assert qnt_annotations[matmul_node.input[1]] == "BIPOLAR"
        init_ind = inits.index(matmul_node.input[1])
        # ONNX MatMul weights are laid out transposed w.r.t. the PyTorch layer.
        first_fc = lfc.features[2]
        int_weights_pytorch = first_fc.int_weight.transpose(1, 0).detach().numpy()
        int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
        assert (int_weights_onnx == int_weights_pytorch).all()
    finally:
        # Clean up even when export or an assertion fails, so no stale model
        # file is left in the working directory.
        if os.path.exists(export_onnx_path):
            os.remove(export_onnx_path)


def test_brevitas_trained_lfc_pytorch():
    """Check pretrained LFC-w1a1 PyTorch outputs on one MNIST test vector.

    Loads the checkpoint at ``trained_lfc_checkpoint`` and downloads the ONNX
    model-zoo MNIST archive. It runs ``test_data_set_0/input_0.pb`` through the
    model in eval mode and compares the logits to hard-coded reference values.
    """
    # load pretrained weights into LFC-w1a1
    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
    lfc.load_state_dict(checkpoint["state_dict"])
    # Remove an archive left over from an earlier run (best effort); presumably
    # so wget.download saves under the expected name.
    try:
        os.remove(os.path.join("/tmp", mnist_onnx_filename))
    except OSError:
        pass
    # download some MNIST test data
    dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
    try:
        shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
        # load one of the test vectors
        input_proto = onnx.TensorProto()
        input_pb_path = os.path.join(
            mnist_onnx_local_dir, "mnist", "test_data_set_0", "input_0.pb"
        )
        with open(input_pb_path, "rb") as f:
            input_proto.ParseFromString(f.read())
        input_tensor = torch.from_numpy(nph.to_array(input_proto)).float()
        assert input_tensor.shape == (1, 1, 28, 28)
        # do forward pass in PyTorch/Brevitas
        produced = lfc.forward(input_tensor).detach().numpy()
        expected = [
            [
                3.3253,
                -2.5652,
                9.2157,
                -1.4251,
                1.4251,
                -3.3728,
                0.2850,
                -0.5700,
                7.0781,
                -1.2826,
            ]
        ]
        assert np.allclose(produced, expected, atol=1e-4)
    finally:
        # remove the downloaded archive and extracted files even on failure
        os.remove(dl_ret)
        shutil.rmtree(mnist_onnx_local_dir, ignore_errors=True)


def test_brevitas_to_onnx_export_and_exec():
    """Check that FINN ONNX execution matches PyTorch on one MNIST vector.

    Loads the pretrained LFC-w1a1 checkpoint and exports it to FINN-ONNX (with
    ONNX shape inference applied). It runs ``test_data_set_0/input_0.pb``
    through both ``oxe.execute_onnx`` and the PyTorch model, and requires the
    outputs to agree within ``atol=1e-3``.
    """
    # Put the model in eval mode explicitly (as the pure-PyTorch test does) so
    # the PyTorch reference pass does not depend on the exporter changing the
    # module's training mode as a side effect.
    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
    lfc.load_state_dict(checkpoint["state_dict"])
    # Remove an archive left over from an earlier run (best effort); presumably
    # so wget.download saves under the expected name.
    try:
        os.remove(os.path.join("/tmp", mnist_onnx_filename))
    except OSError:
        pass
    dl_ret = None
    try:
        bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
        model = onnx.load(export_onnx_path)
        # call ONNX shape inference to make sure we have value_info fields for all
        # the intermediate tensors in the graph
        model = si.infer_shapes(model)
        dl_ret = wget.download(
            mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp"
        )
        shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
        # load one of the test vectors (decoded once, shared by both runs)
        input_proto = onnx.TensorProto()
        input_pb_path = os.path.join(
            mnist_onnx_local_dir, "mnist", "test_data_set_0", "input_0.pb"
        )
        with open(input_pb_path, "rb") as f:
            input_proto.ParseFromString(f.read())
        input_array = nph.to_array(input_proto)
        # run using FINN-based execution; the exported graph input is expected
        # to be named "0"
        output_dict = oxe.execute_onnx(model, {"0": input_array})
        # select the graph's (single) output by name rather than dict order
        produced = output_dict[model.graph.output[0].name]
        # run using PyTorch/Brevitas
        input_tensor = torch.from_numpy(input_array).float()
        assert input_tensor.shape == (1, 1, 28, 28)
        expected = lfc.forward(input_tensor).detach().numpy()
        assert np.allclose(produced, expected, atol=1e-3)
    finally:
        # remove the downloaded archive, extracted files and exported model
        # even when a step above fails
        if dl_ret is not None:
            os.remove(dl_ret)
        shutil.rmtree(mnist_onnx_local_dir, ignore_errors=True)
        if os.path.exists(export_onnx_path):
            os.remove(export_onnx_path)