diff --git a/tests/end2end/test_end2end_cybsec_mlp.py b/tests/end2end/test_end2end_cybsec_mlp.py index 57e873516a6d3539c125c6f5fcc0a405e939a800..962153579024b1fbeae17b061c7ef8837e735fa1 100644 --- a/tests/end2end/test_end2end_cybsec_mlp.py +++ b/tests/end2end/test_end2end_cybsec_mlp.py @@ -28,6 +28,7 @@ import torch from brevitas.nn import QuantLinear, QuantReLU +from brevitas.quant_tensor import QuantTensor import torch.nn as nn import numpy as np from brevitas.core.quant import QuantType @@ -115,22 +116,31 @@ def test_end2end_cybsec_mlp_export(): model_for_export = CybSecMLPForExport(model) export_onnx_path = get_checkpoint_name("export") input_shape = (1, 600) - bo.export_finn_onnx(model_for_export, input_shape, export_onnx_path) + + input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32) + input_a = 2 * input_a - 1 + scale = 1.0 + input_t = torch.from_numpy(input_a * scale) + input_qt = QuantTensor( + input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True + ) + + bo.export_finn_onnx( + model_for_export, export_path=export_onnx_path, input_t=input_qt + ) assert os.path.isfile(export_onnx_path) # fix input datatype finn_model = ModelWrapper(export_onnx_path) finnonnx_in_tensor_name = finn_model.graph.input[0].name - finn_model.set_tensor_datatype(finnonnx_in_tensor_name, DataType.BIPOLAR) - finn_model.save(export_onnx_path) assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1, 600) # verify a few exported ops - assert finn_model.graph.node[0].op_type == "Add" - assert finn_model.graph.node[1].op_type == "Div" - assert finn_model.graph.node[2].op_type == "MatMul" + assert finn_model.graph.node[1].op_type == "Add" + assert finn_model.graph.node[2].op_type == "Div" + assert finn_model.graph.node[3].op_type == "MatMul" assert finn_model.graph.node[-1].op_type == "MultiThreshold" # verify datatypes on some tensors assert finn_model.get_tensor_datatype(finnonnx_in_tensor_name) == DataType.BIPOLAR - first_matmul_w_name = finn_model.graph.node[2].input[1] + first_matmul_w_name = finn_model.graph.node[3].input[1] assert finn_model.get_tensor_datatype(first_matmul_w_name) == DataType.INT2