diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1c99a1da5f5bdb93f0b8b089fddb20e3c078ed --- /dev/null +++ b/tests/test_infer_datatypes.py @@ -0,0 +1,40 @@ +import os + +import brevitas.onnx as bo +import torch +from models.LFC import LFC + +import finn.transformation.fold_constants as fc +import finn.transformation.general as tg +import finn.transformation.infer_datatypes as id +import finn.transformation.infer_shapes as si +from finn.core.datatype import DataType +from finn.core.modelwrapper import ModelWrapper + +export_onnx_path = "test_output_lfc.onnx" +# TODO get from config instead, hardcoded to Docker path for now +trained_lfc_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" +) + + +def test_infer_datatypes(): + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform_single(si.infer_shapes) + model = model.transform_repeated(fc.fold_constants) + model = model.transform_single(tg.give_unique_node_names) + model = model.transform_single(tg.give_readable_tensor_names) + model = model.transform_repeated(id.infer_datatypes) + assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32 + assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32 + assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32 + assert model.get_tensor_datatype("MatMul_3_out0") == DataType.INT32 + assert model.get_tensor_datatype("Sign_0_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("Sign_1_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("Sign_2_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("Sign_3_out0") == DataType.BIPOLAR + os.remove(export_onnx_path)