# Provenance: reconstructed from a git patch exported from GitLab.
#   Commit:  c2d59d3ff5fb70c47804e4b97e2d02fee1976429
#   Author:  Yaman Umuroglu <yamanu@xilinx.com>
#   Date:    Fri, 9 Apr 2021 14:33:28 +0200
#   Subject: [PATCH] [Test] end2end test for UNSW-NB15 MLP
#   File:    tests/end2end/test_end2end_cybsec_mlp.py (new file, index 13fb33eaf)
"""End-to-end test for the UNSW-NB15 cybersecurity MLP.

The test is split into three steps that hand results to each other through
checkpoints stored under ``FINN_BUILD_DIR``:

1. ``export``    - load the trained Brevitas model and export it to FINN-ONNX
2. ``build``     - run the FINN dataflow build and check the generated outputs
3. ``run_on_hw`` - copy the deployment package to a board and validate there
"""
import json
import os
import shutil
import subprocess

import brevitas.onnx as bo
import numpy as np
import pkg_resources as pk
import pytest
import torch
import torch.nn as nn
import wget
from brevitas.core.quant import QuantType
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear, QuantReLU

import finn.builder.build_dataflow as build
import finn.builder.build_dataflow_config as build_cfg
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.util.basic import make_build_dir
from finn.util.test import get_build_env, load_test_checkpoint_or_skip

target_clk_ns = 10
build_kind = "zynq"
build_dir = os.environ["FINN_BUILD_DIR"]


def get_checkpoint_name(step):
    """Return the checkpoint path under ``build_dir`` for the given test step.

    The "build" step is checkpointed as a whole directory; every other step
    is checkpointed as a single ONNX file named after the step.
    """
    if step == "build":
        # checkpoint for build step is an entire dir
        return build_dir + "/end2end_cybsecmlp_build"
    else:
        # other checkpoints are onnx files
        return build_dir + "/end2end_cybsecmlp_%s.onnx" % (step)


class CybSecMLPForExport(nn.Module):
    """Wrap a pretrained model with bipolar input and output handling."""

    def __init__(self, my_pretrained_model):
        """Store the pretrained model and create the binary output quantizer."""
        super(CybSecMLPForExport, self).__init__()
        self.pretrained = my_pretrained_model
        self.qnt_output = QuantIdentity(
            quant_type=QuantType.BINARY, bit_width=1, min_val=-1.0, max_val=1.0
        )

    def forward(self, x):
        """Map bipolar input to {0,1}, run the model, quantize the output."""
        # assume x contains bipolar {-1,1} elems
        # shift from {-1,1} -> {0,1} since that is the
        # input range for the trained network
        x = (x + torch.tensor([1.0])) / 2.0
        out_original = self.pretrained(x)
        out_final = self.qnt_output(out_original)  # output as {-1,1}
        return out_final


def test_end2end_cybsec_mlp_export():
    """Load the trained MLP, pad its input layer and export it to FINN-ONNX."""
    assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
    # load up trained net in Brevitas
    input_size = 593
    # width of the exported network input; the first layer is zero-padded
    # from input_size up to this value
    padded_input_size = 600
    hidden1 = 64
    hidden2 = 64
    hidden3 = 64
    weight_bit_width = 2
    act_bit_width = 2
    num_classes = 1
    model = nn.Sequential(
        QuantLinear(input_size, hidden1, bias=True, weight_bit_width=weight_bit_width),
        nn.BatchNorm1d(hidden1),
        nn.Dropout(0.5),
        QuantReLU(bit_width=act_bit_width),
        QuantLinear(hidden1, hidden2, bias=True, weight_bit_width=weight_bit_width),
        nn.BatchNorm1d(hidden2),
        nn.Dropout(0.5),
        QuantReLU(bit_width=act_bit_width),
        QuantLinear(hidden2, hidden3, bias=True, weight_bit_width=weight_bit_width),
        nn.BatchNorm1d(hidden3),
        nn.Dropout(0.5),
        QuantReLU(bit_width=act_bit_width),
        QuantLinear(hidden3, num_classes, bias=True, weight_bit_width=weight_bit_width),
    )
    trained_state_dict = torch.load(os.path.join(assets_dir, "state_dict.pth"))[
        "models_state_dict"
    ][0]
    model.load_state_dict(trained_state_dict, strict=False)
    W_orig = model[0].weight.data.detach().numpy()
    # pad the second (593-sized) dimension with zeroes at the end so that it
    # matches the padded input width (7 extra columns for 593 -> 600)
    W_new = np.pad(W_orig, [(0, 0), (0, padded_input_size - input_size)])
    model[0].weight.data = torch.from_numpy(W_new)
    model_for_export = CybSecMLPForExport(model)
    export_onnx_path = get_checkpoint_name("export")
    input_shape = (1, padded_input_size)
    bo.export_finn_onnx(model_for_export, input_shape, export_onnx_path)
    assert os.path.isfile(export_onnx_path)
    # fix input datatype
    finn_model = ModelWrapper(export_onnx_path)
    finnonnx_in_tensor_name = finn_model.graph.input[0].name
    finn_model.set_tensor_datatype(finnonnx_in_tensor_name, DataType.BIPOLAR)
    finn_model.save(export_onnx_path)
    assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == input_shape
    assert len(finn_model.graph.node) == 30
    assert finn_model.graph.node[0].op_type == "Add"
    assert finn_model.graph.node[1].op_type == "Div"
    assert finn_model.graph.node[2].op_type == "MatMul"
    assert finn_model.graph.node[-1].op_type == "MultiThreshold"


@pytest.mark.slow
@pytest.mark.vivado
def test_end2end_cybsec_mlp_build():
    """Run the dataflow build on the exported model and check its outputs.

    Skips (via ``load_test_checkpoint_or_skip``) when the export checkpoint
    is not available. On success the generated deployment package is copied
    to the "build" checkpoint directory for the hardware test step.
    """
    model_file = get_checkpoint_name("export")
    load_test_checkpoint_or_skip(model_file)
    build_env = get_build_env(build_kind, target_clk_ns)
    output_dir = make_build_dir("test_end2end_cybsec_mlp_build")

    cfg = build.DataflowBuildConfig(
        output_dir=output_dir,
        target_fps=1000000,
        synth_clk_period_ns=target_clk_ns,
        board=build_env["board"],
        shell_flow_type=build_cfg.ShellFlowType.VIVADO_ZYNQ,
        generate_outputs=[
            build_cfg.DataflowOutputType.ESTIMATE_REPORTS,
            build_cfg.DataflowOutputType.BITFILE,
            build_cfg.DataflowOutputType.PYNQ_DRIVER,
            build_cfg.DataflowOutputType.DEPLOYMENT_PACKAGE,
        ],
    )
    build.build_dataflow_cfg(model_file, cfg)
    # check the generated files
    assert os.path.isfile(output_dir + "/time_per_step.json")
    assert os.path.isfile(output_dir + "/final_hw_config.json")
    assert os.path.isfile(output_dir + "/driver/driver.py")
    est_cycles_report = output_dir + "/report/estimate_layer_cycles.json"
    assert os.path.isfile(est_cycles_report)
    est_res_report = output_dir + "/report/estimate_layer_resources.json"
    assert os.path.isfile(est_res_report)
    assert os.path.isfile(output_dir + "/report/estimate_network_performance.json")
    assert os.path.isfile(output_dir + "/bitfile/finn-accel.bit")
    assert os.path.isfile(output_dir + "/bitfile/finn-accel.hwh")
    assert os.path.isfile(output_dir + "/report/post_synth_resources.xml")
    assert os.path.isfile(output_dir + "/report/post_route_timing.rpt")
    # examine the report contents
    with open(est_cycles_report, "r") as f:
        est_cycles_dict = json.load(f)
        assert est_cycles_dict["StreamingFCLayer_Batch_0"] == 80
        assert est_cycles_dict["StreamingFCLayer_Batch_1"] == 64
    with open(est_res_report, "r") as f:
        est_res_dict = json.load(f)
        assert est_res_dict["total"]["LUT"] == 11360.0
        assert est_res_dict["total"]["BRAM_18K"] == 36.0
    checkpoint_dir = get_checkpoint_name("build")
    # copytree refuses to write into an existing destination, so remove any
    # deployment package left behind by a previous run first
    if os.path.isdir(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    shutil.copytree(output_dir + "/deploy", checkpoint_dir)


def test_end2end_cybsec_mlp_run_on_hw():
    """Copy the deployment package to the board and run validation there.

    Skips when the "build" checkpoint directory from the previous step is
    missing. The remote validation output is checked against the expected
    accuracy log lines.
    """
    build_env = get_build_env(build_kind, target_clk_ns)
    assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
    deploy_dir = get_checkpoint_name("build")
    # the build step is marked slow/vivado and may not have run; skip instead
    # of failing, like the build step does for a missing export checkpoint
    if not os.path.isdir(deploy_dir):
        pytest.skip(deploy_dir + " not found from previous test step, skipping")
    driver_dir = deploy_dir + "/driver"
    assert os.path.isdir(driver_dir)
    # put all assets into driver dir
    shutil.copy(os.path.join(assets_dir, "validate-unsw-nb15.py"), driver_dir)
    # put a copy of binarized dataset into driver dir
    dataset_url = (
        "https://zenodo.org/record/4519767/files/unsw_nb15_binarized.npz?download=1"
    )
    dataset_local = driver_dir + "/unsw_nb15_binarized.npz"
    if not os.path.isfile(dataset_local):
        wget.download(dataset_url, out=dataset_local)
    assert os.path.isfile(dataset_local)
    # create a shell script for running validation: 10 batches x 10 samples
    with open(driver_dir + "/validate.sh", "w") as f:
        f.write(
            """#!/bin/bash
cd %s/driver
echo %s | sudo -S python3.6 validate-unsw-nb15.py --batchsize=10 --limit_batches=10
        """
            % (
                build_env["target_dir"] + "/end2end_cybsecmlp_build",
                build_env["password"],
            )
        )
    # set up rsync command
    remote_target = "%s@%s:%s" % (
        build_env["username"],
        build_env["ip"],
        build_env["target_dir"],
    )
    rsync_res = subprocess.run(
        [
            "sshpass",
            "-p",
            build_env["password"],
            "rsync",
            "-avz",
            deploy_dir,
            remote_target,
        ]
    )
    assert rsync_res.returncode == 0
    remote_verif_cmd = [
        "sshpass",
        "-p",
        build_env["password"],
        "ssh",
        "%s@%s" % (build_env["username"], build_env["ip"]),
        "sh",
        build_env["target_dir"] + "/end2end_cybsecmlp_build/driver/validate.sh",
    ]
    verif_res = subprocess.run(
        remote_verif_cmd,
        stdout=subprocess.PIPE,
        universal_newlines=True,
        input=build_env["password"],
    )
    assert verif_res.returncode == 0
    log_output = verif_res.stdout.split("\n")
    # fail with a readable message rather than an IndexError on short output
    assert len(log_output) >= 3, "unexpected validation output: %r" % (
        verif_res.stdout,
    )
    assert log_output[-3] == "batch 10 / 10 : total OK 93 NOK 7"
    assert log_output[-2] == "Final accuracy: 93.000000"