From e00cf5cf574cf7d945ed164c8b19b562fd574f3c Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 7 Nov 2019 16:01:34 +0000 Subject: [PATCH] [Test] add test_infer_datatypes --- tests/test_infer_datatypes.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/test_infer_datatypes.py diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py new file mode 100644 index 000000000..fd1c99a1d --- /dev/null +++ b/tests/test_infer_datatypes.py @@ -0,0 +1,40 @@ +import os + +import brevitas.onnx as bo +import torch +from models.LFC import LFC + +import finn.transformation.fold_constants as fc +import finn.transformation.general as tg +import finn.transformation.infer_datatypes as id +import finn.transformation.infer_shapes as si +from finn.core.datatype import DataType +from finn.core.modelwrapper import ModelWrapper + +export_onnx_path = "test_output_lfc.onnx" +# TODO get from config instead, hardcoded to Docker path for now +trained_lfc_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" +) + + +def test_infer_datatypes(): + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform_single(si.infer_shapes) + model = model.transform_repeated(fc.fold_constants) + model = model.transform_single(tg.give_unique_node_names) + model = model.transform_single(tg.give_readable_tensor_names) + model = model.transform_repeated(id.infer_datatypes) + assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32 + assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32 + assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32 + assert model.get_tensor_datatype("MatMul_3_out0") == DataType.INT32 + assert model.get_tensor_datatype("Sign_0_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("Sign_1_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("Sign_2_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("Sign_3_out0") == DataType.BIPOLAR + os.remove(export_onnx_path) -- GitLab