Skip to content
Snippets Groups Projects
Commit 7e04af91 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Enable QONNX ingestion for test_end2end_cybsec_mlp_* tests.

parent 658c1730
No related branches found
No related tags found
No related merge requests found
......@@ -40,13 +40,16 @@ import torch
import torch.nn as nn
import wget
from brevitas.core.quant import QuantType
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from brevitas.nn import QuantIdentity, QuantLinear, QuantReLU
from brevitas.quant_tensor import QuantTensor
from qonnx.util.cleanup import cleanup as qonnx_cleanup
import finn.builder.build_dataflow as build
import finn.builder.build_dataflow_config as build_cfg
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from finn.util.basic import make_build_dir
from finn.util.test import get_build_env, load_test_checkpoint_or_skip
......@@ -55,13 +58,13 @@ build_kind = "zynq"
build_dir = os.environ["FINN_BUILD_DIR"]
def get_checkpoint_name(step):
def get_checkpoint_name(step, QONNX_export):
if step == "build":
# checkpoint for build step is an entire dir
return build_dir + "/end2end_cybsecmlp_build"
return build_dir + "/end2end_cybsecmlp_build_QONNX-%d" % (QONNX_export)
else:
# other checkpoints are onnx files
return build_dir + "/end2end_cybsecmlp_%s.onnx" % (step)
return build_dir + "/end2end_cybsecmlp_QONNX-%d_%s.onnx" % (QONNX_export, step)
class CybSecMLPForExport(nn.Module):
......@@ -82,7 +85,8 @@ class CybSecMLPForExport(nn.Module):
return out_final
def test_end2end_cybsec_mlp_export():
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_end2end_cybsec_mlp_export(QONNX_export):
assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
# load up trained net in Brevitas
input_size = 593
......@@ -116,7 +120,7 @@ def test_end2end_cybsec_mlp_export():
W_new = np.pad(W_orig, [(0, 0), (0, 7)])
model[0].weight.data = torch.from_numpy(W_new)
model_for_export = CybSecMLPForExport(model)
export_onnx_path = get_checkpoint_name("export")
export_onnx_path = get_checkpoint_name("export", QONNX_export)
input_shape = (1, 600)
# create a QuantTensor instance to mark the input as bipolar during export
input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32)
......@@ -127,32 +131,60 @@ def test_end2end_cybsec_mlp_export():
input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True
)
bo.export_finn_onnx(
model_for_export, export_path=export_onnx_path, input_t=input_qt
)
if QONNX_export:
# With the BrevitasONNXManager we need to manually set
# the FINN DataType at the input
BrevitasONNXManager.export(
model_for_export, input_shape, export_path=export_onnx_path
)
model = ModelWrapper(export_onnx_path)
model.set_tensor_datatype(model.graph.input[0].name, DataType["BIPOLAR"])
model.save(export_onnx_path)
qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(ConvertQONNXtoFINN())
model.save(export_onnx_path)
else:
bo.export_finn_onnx(
model_for_export, export_path=export_onnx_path, input_t=input_qt
)
assert os.path.isfile(export_onnx_path)
# fix input datatype
finn_model = ModelWrapper(export_onnx_path)
finnonnx_in_tensor_name = finn_model.graph.input[0].name
assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1, 600)
# verify a few exported ops
assert finn_model.graph.node[1].op_type == "Add"
assert finn_model.graph.node[2].op_type == "Div"
assert finn_model.graph.node[3].op_type == "MatMul"
assert finn_model.graph.node[-1].op_type == "MultiThreshold"
if QONNX_export:
# The first "Mul" node dosen't exist in the QONNX export,
# because the QuantTensor scale is not exported.
# However, this node would have been unity scale anyways and
# the models are still equivalent.
assert finn_model.graph.node[0].op_type == "Add"
assert finn_model.graph.node[1].op_type == "Div"
assert finn_model.graph.node[2].op_type == "MatMul"
assert finn_model.graph.node[-1].op_type == "MultiThreshold"
else:
assert finn_model.graph.node[0].op_type == "Mul"
assert finn_model.get_initializer(finn_model.graph.node[0].input[1]) == 1.0
assert finn_model.graph.node[1].op_type == "Add"
assert finn_model.graph.node[2].op_type == "Div"
assert finn_model.graph.node[3].op_type == "MatMul"
assert finn_model.graph.node[-1].op_type == "MultiThreshold"
# verify datatypes on some tensors
assert finn_model.get_tensor_datatype(finnonnx_in_tensor_name) == DataType.BIPOLAR
first_matmul_w_name = finn_model.graph.node[3].input[1]
first_matmul_w_name = finn_model.get_nodes_by_op_type("MatMul")[0].input[1]
assert finn_model.get_tensor_datatype(first_matmul_w_name) == DataType.INT2
@pytest.mark.slow
@pytest.mark.vivado
def test_end2end_cybsec_mlp_build():
model_file = get_checkpoint_name("export")
# @pytest.mark.parametrize("QONNX_export", [False, True])
@pytest.mark.parametrize("QONNX_export", [True])
def test_end2end_cybsec_mlp_build(QONNX_export):
model_file = get_checkpoint_name("export", QONNX_export)
load_test_checkpoint_or_skip(model_file)
build_env = get_build_env(build_kind, target_clk_ns)
output_dir = make_build_dir("test_end2end_cybsec_mlp_build")
output_dir = make_build_dir(f"test_end2end_cybsec_mlp_build_QONNX-{QONNX_export}")
cfg = build.DataflowBuildConfig(
output_dir=output_dir,
......@@ -190,13 +222,14 @@ def test_end2end_cybsec_mlp_build():
est_res_dict = json.load(f)
assert est_res_dict["total"]["LUT"] == 11360.0
assert est_res_dict["total"]["BRAM_18K"] == 36.0
shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build"))
shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build", QONNX_export))
def test_end2end_cybsec_mlp_run_on_hw():
@pytest.mark.parametrize("QONNX_export", [False, True])
def test_end2end_cybsec_mlp_run_on_hw(QONNX_export):
build_env = get_build_env(build_kind, target_clk_ns)
assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
deploy_dir = get_checkpoint_name("build")
deploy_dir = get_checkpoint_name("build", QONNX_export)
if not os.path.isdir(deploy_dir):
pytest.skip(deploy_dir + " not found from previous test step, skipping")
driver_dir = deploy_dir + "/driver"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment