Commit 7e04af91 authored by Hendrik Borras

Enable QONNX ingestion for test_end2end_cybsec_mlp_* tests.

parent 658c1730
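
In brief, the QONNX ingestion path that the parametrized tests below now exercise runs the generic Brevitas/QONNX export, cleans the graph, and converts it into FINN's internal ONNX dialect. A condensed sketch of that flow, using only names that appear in the diff (model_for_export, input_shape, export_onnx_path), looks like this:

    from brevitas.export.onnx.generic.manager import BrevitasONNXManager
    from qonnx.util.cleanup import cleanup as qonnx_cleanup
    from finn.core.datatype import DataType
    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

    # export to the generic QONNX format; unlike bo.export_finn_onnx, this
    # takes an input shape rather than a QuantTensor
    BrevitasONNXManager.export(model_for_export, input_shape, export_path=export_onnx_path)
    # the QONNX export does not mark the bipolar input, so set the FINN DataType by hand
    model = ModelWrapper(export_onnx_path)
    model.set_tensor_datatype(model.graph.input[0].name, DataType["BIPOLAR"])
    model.save(export_onnx_path)
    # tidy up the QONNX graph, then lower it to FINN's ONNX dialect
    qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
    model = ModelWrapper(export_onnx_path).transform(ConvertQONNXtoFINN())
    model.save(export_onnx_path)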
@@ -40,13 +40,16 @@ import torch
 import torch.nn as nn
 import wget
 from brevitas.core.quant import QuantType
+from brevitas.export.onnx.generic.manager import BrevitasONNXManager
 from brevitas.nn import QuantIdentity, QuantLinear, QuantReLU
 from brevitas.quant_tensor import QuantTensor
+from qonnx.util.cleanup import cleanup as qonnx_cleanup

 import finn.builder.build_dataflow as build
 import finn.builder.build_dataflow_config as build_cfg
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
 from finn.util.basic import make_build_dir
 from finn.util.test import get_build_env, load_test_checkpoint_or_skip
@@ -55,13 +58,13 @@ build_kind = "zynq"
 build_dir = os.environ["FINN_BUILD_DIR"]


-def get_checkpoint_name(step):
+def get_checkpoint_name(step, QONNX_export):
     if step == "build":
         # checkpoint for build step is an entire dir
-        return build_dir + "/end2end_cybsecmlp_build"
+        return build_dir + "/end2end_cybsecmlp_build_QONNX-%d" % (QONNX_export)
     else:
         # other checkpoints are onnx files
-        return build_dir + "/end2end_cybsecmlp_%s.onnx" % (step)
+        return build_dir + "/end2end_cybsecmlp_QONNX-%d_%s.onnx" % (QONNX_export, step)


 class CybSecMLPForExport(nn.Module):
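
Because QONNX_export arrives as a bool and is formatted with %d, the checkpoint names simply carry a 0/1 suffix. For example (with build_dir standing in for $FINN_BUILD_DIR):

    get_checkpoint_name("export", True)   # -> build_dir + "/end2end_cybsecmlp_QONNX-1_export.onnx"
    get_checkpoint_name("build", False)   # -> build_dir + "/end2end_cybsecmlp_build_QONNX-0"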
@@ -82,7 +85,8 @@ class CybSecMLPForExport(nn.Module):
         return out_final


-def test_end2end_cybsec_mlp_export():
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_end2end_cybsec_mlp_export(QONNX_export):
     assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
     # load up trained net in Brevitas
     input_size = 593
@@ -116,7 +120,7 @@ def test_end2end_cybsec_mlp_export():
     W_new = np.pad(W_orig, [(0, 0), (0, 7)])
     model[0].weight.data = torch.from_numpy(W_new)
     model_for_export = CybSecMLPForExport(model)
-    export_onnx_path = get_checkpoint_name("export")
+    export_onnx_path = get_checkpoint_name("export", QONNX_export)
     input_shape = (1, 600)
     # create a QuantTensor instance to mark the input as bipolar during export
     input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32)
@@ -127,32 +131,60 @@
         input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True
     )
-    bo.export_finn_onnx(
-        model_for_export, export_path=export_onnx_path, input_t=input_qt
-    )
+    if QONNX_export:
+        # With the BrevitasONNXManager we need to manually set
+        # the FINN DataType at the input
+        BrevitasONNXManager.export(
+            model_for_export, input_shape, export_path=export_onnx_path
+        )
+        model = ModelWrapper(export_onnx_path)
+        model.set_tensor_datatype(model.graph.input[0].name, DataType["BIPOLAR"])
+        model.save(export_onnx_path)
+        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
+        model = ModelWrapper(export_onnx_path)
+        model = model.transform(ConvertQONNXtoFINN())
+        model.save(export_onnx_path)
+    else:
+        bo.export_finn_onnx(
+            model_for_export, export_path=export_onnx_path, input_t=input_qt
+        )
     assert os.path.isfile(export_onnx_path)
     # fix input datatype
     finn_model = ModelWrapper(export_onnx_path)
     finnonnx_in_tensor_name = finn_model.graph.input[0].name
     assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1, 600)
     # verify a few exported ops
-    assert finn_model.graph.node[1].op_type == "Add"
-    assert finn_model.graph.node[2].op_type == "Div"
-    assert finn_model.graph.node[3].op_type == "MatMul"
-    assert finn_model.graph.node[-1].op_type == "MultiThreshold"
+    if QONNX_export:
+        # The first "Mul" node doesn't exist in the QONNX export,
+        # because the QuantTensor scale is not exported.
+        # However, this node would have been unity scale anyways and
+        # the models are still equivalent.
+        assert finn_model.graph.node[0].op_type == "Add"
+        assert finn_model.graph.node[1].op_type == "Div"
+        assert finn_model.graph.node[2].op_type == "MatMul"
+        assert finn_model.graph.node[-1].op_type == "MultiThreshold"
+    else:
+        assert finn_model.graph.node[0].op_type == "Mul"
+        assert finn_model.get_initializer(finn_model.graph.node[0].input[1]) == 1.0
+        assert finn_model.graph.node[1].op_type == "Add"
+        assert finn_model.graph.node[2].op_type == "Div"
+        assert finn_model.graph.node[3].op_type == "MatMul"
+        assert finn_model.graph.node[-1].op_type == "MultiThreshold"
     # verify datatypes on some tensors
     assert finn_model.get_tensor_datatype(finnonnx_in_tensor_name) == DataType.BIPOLAR
-    first_matmul_w_name = finn_model.graph.node[3].input[1]
+    first_matmul_w_name = finn_model.get_nodes_by_op_type("MatMul")[0].input[1]
     assert finn_model.get_tensor_datatype(first_matmul_w_name) == DataType.INT2


 @pytest.mark.slow
 @pytest.mark.vivado
-def test_end2end_cybsec_mlp_build():
-    model_file = get_checkpoint_name("export")
+# @pytest.mark.parametrize("QONNX_export", [False, True])
+@pytest.mark.parametrize("QONNX_export", [True])
+def test_end2end_cybsec_mlp_build(QONNX_export):
+    model_file = get_checkpoint_name("export", QONNX_export)
     load_test_checkpoint_or_skip(model_file)
     build_env = get_build_env(build_kind, target_clk_ns)
-    output_dir = make_build_dir("test_end2end_cybsec_mlp_build")
+    output_dir = make_build_dir(f"test_end2end_cybsec_mlp_build_QONNX-{QONNX_export}")

     cfg = build.DataflowBuildConfig(
         output_dir=output_dir,
@@ -190,13 +222,14 @@ def test_end2end_cybsec_mlp_build():
         est_res_dict = json.load(f)
     assert est_res_dict["total"]["LUT"] == 11360.0
     assert est_res_dict["total"]["BRAM_18K"] == 36.0
-    shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build"))
+    shutil.copytree(output_dir + "/deploy", get_checkpoint_name("build", QONNX_export))


-def test_end2end_cybsec_mlp_run_on_hw():
+@pytest.mark.parametrize("QONNX_export", [False, True])
+def test_end2end_cybsec_mlp_run_on_hw(QONNX_export):
     build_env = get_build_env(build_kind, target_clk_ns)
     assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
-    deploy_dir = get_checkpoint_name("build")
+    deploy_dir = get_checkpoint_name("build", QONNX_export)
     if not os.path.isdir(deploy_dir):
         pytest.skip(deploy_dir + " not found from previous test step, skipping")
     driver_dir = deploy_dir + "/driver"
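
To reproduce these parametrized runs locally, a pytest invocation along these lines should work; the test file path is an assumption here, adjust it to wherever this file sits in your checkout:

    # both export variants
    pytest -v -k "test_end2end_cybsec_mlp_export" tests/end2end/test_end2end_cybsec_mlp_noexec.py
    # only the QONNX variant (pytest renders the bool parameter in the test id)
    pytest -v "tests/end2end/test_end2end_cybsec_mlp_noexec.py::test_end2end_cybsec_mlp_export[True]"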