Skip to content
Snippets Groups Projects
Commit 73b8df59 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] cybsec: use Brevitas' new capabilities for inp qnt marking

parent bdb15cbd
No related branches found
No related tags found
No related merge requests found
......@@ -28,6 +28,7 @@
import torch
from brevitas.nn import QuantLinear, QuantReLU
from brevitas.quant_tensor import QuantTensor
import torch.nn as nn
import numpy as np
from brevitas.core.quant import QuantType
......@@ -115,22 +116,31 @@ def test_end2end_cybsec_mlp_export():
model_for_export = CybSecMLPForExport(model)
export_onnx_path = get_checkpoint_name("export")
input_shape = (1, 600)
bo.export_finn_onnx(model_for_export, input_shape, export_onnx_path)
input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32)
input_a = 2 * input_a - 1
scale = 1.0
input_t = torch.from_numpy(input_a * scale)
input_qt = QuantTensor(
input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True
)
bo.export_finn_onnx(
model_for_export, export_path=export_onnx_path, input_t=input_qt
)
assert os.path.isfile(export_onnx_path)
# fix input datatype
finn_model = ModelWrapper(export_onnx_path)
finnonnx_in_tensor_name = finn_model.graph.input[0].name
finn_model.set_tensor_datatype(finnonnx_in_tensor_name, DataType.BIPOLAR)
finn_model.save(export_onnx_path)
assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1, 600)
# verify a few exported ops
assert finn_model.graph.node[0].op_type == "Add"
assert finn_model.graph.node[1].op_type == "Div"
assert finn_model.graph.node[2].op_type == "MatMul"
assert finn_model.graph.node[1].op_type == "Add"
assert finn_model.graph.node[2].op_type == "Div"
assert finn_model.graph.node[3].op_type == "MatMul"
assert finn_model.graph.node[-1].op_type == "MultiThreshold"
# verify datatypes on some tensors
assert finn_model.get_tensor_datatype(finnonnx_in_tensor_name) == DataType.BIPOLAR
first_matmul_w_name = finn_model.graph.node[2].input[1]
first_matmul_w_name = finn_model.graph.node[3].input[1]
assert finn_model.get_tensor_datatype(first_matmul_w_name) == DataType.INT2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment