From 73b8df59b1d06cfb8a484650873dc931bf31c443 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Mon, 10 May 2021 11:48:57 +0100 Subject: [PATCH] [Test] cybsec: use Brevitas' new capabilities for inp qnt marking --- tests/end2end/test_end2end_cybsec_mlp.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/end2end/test_end2end_cybsec_mlp.py b/tests/end2end/test_end2end_cybsec_mlp.py index 57e873516..962153579 100644 --- a/tests/end2end/test_end2end_cybsec_mlp.py +++ b/tests/end2end/test_end2end_cybsec_mlp.py @@ -28,6 +28,7 @@ import torch from brevitas.nn import QuantLinear, QuantReLU +from brevitas.quant_tensor import QuantTensor import torch.nn as nn import numpy as np from brevitas.core.quant import QuantType @@ -115,22 +116,31 @@ def test_end2end_cybsec_mlp_export(): model_for_export = CybSecMLPForExport(model) export_onnx_path = get_checkpoint_name("export") input_shape = (1, 600) - bo.export_finn_onnx(model_for_export, input_shape, export_onnx_path) + + input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32) + input_a = 2 * input_a - 1 + scale = 1.0 + input_t = torch.from_numpy(input_a * scale) + input_qt = QuantTensor( + input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True + ) + + bo.export_finn_onnx( + model_for_export, export_path=export_onnx_path, input_t=input_qt + ) assert os.path.isfile(export_onnx_path) # fix input datatype finn_model = ModelWrapper(export_onnx_path) finnonnx_in_tensor_name = finn_model.graph.input[0].name - finn_model.set_tensor_datatype(finnonnx_in_tensor_name, DataType.BIPOLAR) - finn_model.save(export_onnx_path) assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1, 600) # verify a few exported ops - assert finn_model.graph.node[0].op_type == "Add" - assert finn_model.graph.node[1].op_type == "Div" - assert finn_model.graph.node[2].op_type == "MatMul" + assert finn_model.graph.node[1].op_type == "Add" + assert finn_model.graph.node[2].op_type == "Div" + assert finn_model.graph.node[3].op_type == "MatMul" assert finn_model.graph.node[-1].op_type == "MultiThreshold" # verify datatypes on some tensors assert finn_model.get_tensor_datatype(finnonnx_in_tensor_name) == DataType.BIPOLAR - first_matmul_w_name = finn_model.graph.node[2].input[1] + first_matmul_w_name = finn_model.graph.node[3].input[1] assert finn_model.get_tensor_datatype(first_matmul_w_name) == DataType.INT2 -- GitLab