{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Quantized MLP on UNSW-NB15 with Brevitas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font color=\"red\">**Live FINN tutorial:** We recommend clicking **Cell -> Run All** when you start reading this notebook for \"latency hiding\".</font>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will show how to create, train and export a quantized Multi Layer Perceptron (MLP) with quantized weights and activations with [Brevitas](https://github.com/Xilinx/brevitas).\n",
    "Specifically, the task at hand will be to label network packets as normal or suspicious (e.g. originating from an attacker, virus, malware or otherwise) by training on a quantized variant of the UNSW-NB15 dataset. \n",
    "\n",
    "**You won't need a GPU to train the neural net.** This MLP will be small enough to train on a modern x86 CPU, so no GPU is required to follow this tutorial  Alternatively, we provide pre-trained parameters for the MLP if you want to skip the training entirely.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A quick introduction to the task and the dataset\n",
    "\n",
    "*The task:* The goal of [*network intrusion detection*](https://ieeexplore.ieee.org/abstract/document/283931) is to identify, preferably in real time, unauthorized use, misuse, and abuse of computer systems by both system insiders and external penetrators. This may be achieved by a mix of techniques, and machine-learning (ML) based techniques are increasing in popularity. \n",
    "\n",
    "*The dataset:* Several datasets are available for use in ML-based methods for intrusion detection.\n",
    "The **UNSW-NB15** is one such dataset created by the Australian Centre for Cyber Security (ACCS) to provide a comprehensive network based data set which can reflect modern network traffic scenarios. You can find more details about the dataset on [its homepage](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/).\n",
    "\n",
    "*Performance considerations:* FPGAs are commonly used for implementing high-performance packet processing systems that still provide a degree of programmability. To avoid introducing bottlenecks on the network, the DNN implementation must be capable of detecting malicious ones at line rate, which can be millions of packets per second, and is expected to increase further as next-generation networking solutions provide increased\n",
    "throughput. This is a good reason to consider FPGA acceleration for this particular use-case."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline\n",
    "-------------\n",
    "\n",
    "* [Load the UNSW_NB15 Dataset](#load_dataset) \n",
    "* [Define the Quantized MLP Model](#define_quantized_mlp)\n",
    "* [Define Train and Test  Methods](#train_test)\n",
    "    * [(Option 1) Train the Model from Scratch](#train_scratch)\n",
    "    * [(Option 2) Load Pre-Trained Parameters](#load_pretrained)\n",
    "* [Network Surgery Before Export](#network_surgery)\n",
    "* [Export to FINN-ONNX](#export_finn_onnx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This is important -- always import onnx before torch**. This is a workaround for a [known bug](https://github.com/onnx/onnx/issues/2394)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the UNSW_NB15 Dataset <a id='load_dataset'></a>\n",
    "\n",
    "### Dataset Quantization <a id='dataset_qnt'></a>\n",
    "\n",
    "The goal of this notebook is to train a Quantized Neural Network (QNN) to be later deployed as an FPGA accelerator generated by the FINN compiler. Although we can choose a variety of different precisions for the input, [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf) have previously shown we can actually binarize the inputs and still get good (90%+) accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will create a binarized representation for the dataset by following the procedure defined by Murovic and Trost, which we repeat briefly here:\n",
    "\n",
    "* Original features have different formats ranging from integers, floating numbers to strings.\n",
    "* Integers, which for example represent a packet lifetime, are binarized with as many bits as to include the maximum value. \n",
    "* Another case is with features formatted as strings (protocols), which are binarized by simply counting the number of all different strings for each feature and coding them in the appropriate number of bits.\n",
    "* Floating-point numbers are reformatted into fixed-point representation.\n",
    "* In the end, each sample is transformed into a 593-bit wide binary vector. \n",
    "* All vectors are labeled as bad (0) or normal (1)\n",
    "\n",
    "Following Murovic and Trost's open-source implementation provided as a Matlab script [here](https://github.com/TadejMurovic/BNN_Deployment/blob/master/cybersecurity_dataset_unswb15.m), we've created a [Python version](dataloader_quantized.py).\n",
    "\n",
    "<font color=\"red\">**Live FINN tutorial:** Downloading the original dataset and quantizing it can take some time, so we provide a download link to the pre-quantized version for your convenience. </font>"
   ]
  },
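  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The binarization steps listed above can be sketched in a few lines of NumPy. This is only a minimal illustration under stated assumptions, not the code in `dataloader_quantized.py`: the helper names (`int_to_bits`, `strings_to_bits`, `float_to_bits`), the toy feature values and the fixed-point widths are all hypothetical and chosen for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def int_to_bits(values, n_bits=None):\n",
    "    # binarize non-negative integers, most significant bit first\n",
    "    values = np.asarray(values, dtype=np.int64)\n",
    "    if n_bits is None:\n",
    "        # as many bits as are needed to represent the maximum value\n",
    "        n_bits = max(1, int(values.max()).bit_length())\n",
    "    shifts = np.arange(n_bits - 1, -1, -1)\n",
    "    return ((values[:, None] >> shifts) & 1).astype(np.uint8)\n",
    "\n",
    "def strings_to_bits(values):\n",
    "    # give each distinct string an integer code, then binarize the codes\n",
    "    categories = sorted(set(values))\n",
    "    codes = [categories.index(v) for v in values]\n",
    "    n_bits = max(1, (len(categories) - 1).bit_length())\n",
    "    return int_to_bits(codes, n_bits)\n",
    "\n",
    "def float_to_bits(values, int_bits, frac_bits):\n",
    "    # convert non-negative floats to unsigned fixed point, then binarize\n",
    "    scaled = np.round(np.asarray(values) * (1 << frac_bits)).astype(np.int64)\n",
    "    scaled = np.clip(scaled, 0, (1 << (int_bits + frac_bits)) - 1)\n",
    "    return int_to_bits(scaled, int_bits + frac_bits)\n",
    "\n",
    "# toy features (assumed values, not taken from UNSW-NB15)\n",
    "ttl_bits = int_to_bits([254, 31, 0])\n",
    "proto_bits = strings_to_bits(['tcp', 'udp', 'arp', 'tcp'])\n",
    "rate_bits = float_to_bits([0.5, 2.25], int_bits=2, frac_bits=2)\n",
    "\n",
    "# one sample is the concatenation of its per-feature bit vectors\n",
    "sample = np.concatenate([ttl_bits[0], proto_bits[0], rate_bits[0]])\n",
    "print(sample.shape)\n",
    "```\n",
    "\n",
    "Concatenating the per-feature bit vectors like this for every feature is what yields a fixed-width binary input. The 14 bits here are only a toy stand-in for the 593-bit vectors of the real dataset."
   ]
  },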
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "--2022-01-13 11:50:50--  https://zenodo.org/record/4519767/files/unsw_nb15_binarized.npz?download=1\n",
      "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n",
      "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 13391907 (13M) [application/octet-stream]\n",
      "Saving to: 'unsw_nb15_binarized.npz'\n",
      "\n",
      "     0K .......... .......... .......... .......... ..........  0% 1.71M 7s\n",
      "    50K .......... .......... .......... .......... ..........  0% 2.20M 7s\n",
      "   100K .......... .......... .......... .......... ..........  1% 1.72M 7s\n",
      "   150K .......... .......... .......... .......... ..........  1% 2.59M 6s\n",
      "   200K .......... .......... .......... .......... ..........  1% 1.63M 7s\n",
      "   250K .......... .......... .......... .......... ..........  2% 1.97M 7s\n",
      "   300K .......... .......... .......... .......... ..........  2% 2.56M 6s\n",
      "   350K .......... .......... .......... .......... ..........  3% 1.10M 7s\n",
      "   400K .......... .......... .......... .......... ..........  3% 2.69M 7s\n",
      "   450K .......... .......... .......... .......... ..........  3% 2.03M 6s\n",
      "   500K .......... .......... .......... .......... ..........  4% 2.05M 6s\n",
      "   550K .......... .......... .......... .......... ..........  4% 1.97M 6s\n",
      "   600K .......... .......... .......... .......... ..........  4% 2.19M 6s\n",
      "   650K .......... .......... .......... .......... ..........  5% 2.31M 6s\n",
      "   700K .......... .......... .......... .......... ..........  5% 1.51M 6s\n",
      "   750K .......... .......... .......... .......... ..........  6% 2.82M 6s\n",
      "   800K .......... .......... .......... .......... ..........  6% 2.35M 6s\n",
      "   850K .......... .......... .......... .......... ..........  6% 1.08M 6s\n",
      "   900K .......... .......... .......... .......... ..........  7% 1.94M 6s\n",
      "   950K .......... .......... .......... .......... ..........  7% 1.69M 6s\n",
      "  1000K .......... .......... .......... .......... ..........  8% 2.32M 6s\n",
      "  1050K .......... .......... .......... .......... ..........  8% 2.34M 6s\n",
      "  1100K .......... .......... .......... .......... ..........  8% 2.58M 6s\n",
      "  1150K .......... .......... .......... .......... ..........  9%  949K 6s\n",
      "  1200K .......... .......... .......... .......... ..........  9% 1.41M 6s\n",
      "  1250K .......... .......... .......... .......... ..........  9% 4.16M 6s\n",
      "  1300K .......... .......... .......... .......... .......... 10% 2.29M 6s\n",
      "  1350K .......... .......... .......... .......... .......... 10% 2.14M 6s\n",
      "  1400K .......... .......... .......... .......... .......... 11% 2.21M 6s\n",
      "  1450K .......... .......... .......... .......... .......... 11% 2.22M 6s\n",
      "  1500K .......... .......... .......... .......... .......... 11% 1.79M 6s\n",
      "  1550K .......... .......... .......... .......... .......... 12%  990K 6s\n",
      "  1600K .......... .......... .......... .......... .......... 12% 2.12M 6s\n",
      "  1650K .......... .......... .......... .......... .......... 12% 1.81M 6s\n",
      "  1700K .......... .......... .......... .......... .......... 13% 2.50M 6s\n",
      "  1750K .......... .......... .......... .......... .......... 13% 2.39M 6s\n",
      "  1800K .......... .......... .......... .......... .......... 14%  993K 6s\n",
      "  1850K .......... .......... .......... .......... .......... 14% 3.66M 6s\n",
      "  1900K .......... .......... .......... .......... .......... 14% 2.05M 6s\n",
      "  1950K .......... .......... .......... .......... .......... 15% 1.83M 6s\n",
      "  2000K .......... .......... .......... .......... .......... 15% 2.60M 6s\n",
      "  2050K .......... .......... .......... .......... .......... 16% 2.04M 6s\n",
      "  2100K .......... .......... .......... .......... .......... 16% 2.24M 6s\n",
      "  2150K .......... .......... .......... .......... .......... 16% 2.02M 6s\n",
      "  2200K .......... .......... .......... .......... .......... 17% 1.54M 6s\n",
      "  2250K .......... .......... .......... .......... .......... 17% 1.06M 6s\n",
      "  2300K .......... .......... .......... .......... .......... 17% 3.33M 6s\n",
      "  2350K .......... .......... .......... .......... .......... 18% 1.77M 6s\n",
      "  2400K .......... .......... .......... .......... .......... 18% 1.44M 6s\n",
      "  2450K .......... .......... .......... .......... .......... 19% 2.65M 6s\n",
      "  2500K .......... .......... .......... .......... .......... 19% 1.78M 6s\n",
      "  2550K .......... .......... .......... .......... .......... 19% 1.39M 6s\n",
      "  2600K .......... .......... .......... .......... .......... 20% 2.13M 5s\n",
      "  2650K .......... .......... .......... .......... .......... 20% 1.77M 5s\n",
      "  2700K .......... .......... .......... .......... .......... 21% 1.60M 5s\n",
      "  2750K .......... .......... .......... .......... .......... 21%  677K 6s\n",
      "  2800K .......... .......... .......... .......... .......... 21% 5.16M 6s\n",
      "  2850K .......... .......... .......... .......... .......... 22% 1.72M 5s\n",
      "  2900K .......... .......... .......... .......... .......... 22% 13.3M 5s\n",
      "  2950K .......... .......... .......... .......... .......... 22% 2.50M 5s\n",
      "  3000K .......... .......... .......... .......... .......... 23% 1.37M 5s\n",
      "  3050K .......... .......... .......... .......... .......... 23%  734K 5s\n",
      "  3100K .......... .......... .......... .......... .......... 24% 9.25M 5s\n",
      "  3150K .......... .......... .......... .......... .......... 24% 1.18M 5s\n",
      "  3200K .......... .......... .......... .......... .......... 24% 60.6M 5s\n",
      "  3250K .......... .......... .......... .......... .......... 25% 1.64M 5s\n",
      "  3300K .......... .......... .......... .......... .......... 25% 1.76M 5s\n",
      "  3350K .......... .......... .......... .......... .......... 25% 1.54M 5s\n",
      "  3400K .......... .......... .......... .......... .......... 26% 1.56M 5s\n",
      "  3450K .......... .......... .......... .......... .......... 26% 1018K 5s\n",
      "  3500K .......... .......... .......... .......... .......... 27% 1.55M 5s\n",
      "  3550K .......... .......... .......... .......... .......... 27% 1.21M 5s\n",
      "  3600K .......... .......... .......... .......... .......... 27% 2.78M 5s\n",
      "  3650K .......... .......... .......... .......... .......... 28% 2.79M 5s\n",
      "  3700K .......... .......... .......... .......... .......... 28% 2.10M 5s\n",
      "  3750K .......... .......... .......... .......... .......... 29% 1.44M 5s\n",
      "  3800K .......... .......... .......... .......... .......... 29% 2.38M 5s\n",
      "  3850K .......... .......... .......... .......... .......... 29% 2.87M 5s\n",
      "  3900K .......... .......... .......... .......... .......... 30% 1.56M 5s\n",
      "  3950K .......... .......... .......... .......... .......... 30% 2.50M 5s\n",
      "  4000K .......... .......... .......... .......... .......... 30% 1.25M 5s\n",
      "  4050K .......... .......... .......... .......... .......... 31% 2.16M 5s\n",
      "  4100K .......... .......... .......... .......... .......... 31% 1.62M 5s\n",
      "  4150K .......... .......... .......... .......... .......... 32% 3.11M 5s\n",
      "  4200K .......... .......... .......... .......... .......... 32% 1.25M 5s\n",
      "  4250K .......... .......... .......... .......... .......... 32% 5.60M 5s\n",
      "  4300K .......... .......... .......... .......... .......... 33% 1.64M 5s\n",
      "  4350K .......... .......... .......... .......... .......... 33% 1.17M 5s\n",
      "  4400K .......... .......... .......... .......... .......... 34% 2.21M 5s\n",
      "  4450K .......... .......... .......... .......... .......... 34% 1.32M 5s\n"
     ]
    }
   ],
   "source": [
    "! wget -O unsw_nb15_binarized.npz https://zenodo.org/record/4519767/files/unsw_nb15_binarized.npz?download=1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can extract the binarized numpy arrays from the .npz archive and wrap them as a PyTorch `TensorDataset` as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4500K .......... .......... .......... .......... .......... 34% 1.93M 5s\n",
      "  4550K .......... .......... .......... .......... .......... 35% 2.94M 5s\n",
      "  4600K .......... .......... .......... .......... .......... 35% 1.73M 5s\n",
      "  4650K .......... .......... .......... .......... .......... 35% 1.79M 5s\n",
      "  4700K .......... .......... .......... .......... .......... 36% 1.82M 5s\n",
      "  4750K .......... .......... .......... .......... .......... 36% 3.63M 4s\n",
      "  4800K .......... .......... .......... .......... .......... 37% 1.59M 4s\n",
      "  4850K .......... .......... .......... .......... .......... 37% 1.92M 4s\n",
      "  4900K .......... .......... .......... .......... .......... 37% 2.08M 4s\n",
      "  4950K .......... .......... .......... .......... .......... 38% 1.46M 4s\n",
      "  5000K .......... .......... .......... .......... .......... 38% 3.30M 4s\n",
      "  5050K .......... .......... .......... .......... .......... 38% 1.73M 4s\n",
      "  5100K .......... .......... .......... .......... .......... 39% 2.67M 4s\n",
      "  5150K .......... .......... .......... .......... .......... 39% 1016K 4s\n",
      "  5200K .......... .......... .......... .......... .......... 40% 2.61M 4s\n",
      "  5250K .......... .......... .......... .......... .......... 40% 2.05M 4s\n",
      "  5300K .......... .......... .......... .......... .......... 40% 2.23M 4s\n",
      "  5350K .......... .......... .......... .......... .......... 41% 1.68M 4s\n",
      "  5400K .......... .......... .......... .......... .......... 41% 2.31M 4s\n",
      "  5450K .......... .......... .......... .......... .......... 42% 1.60M 4s\n",
      "  5500K .......... .......... .......... .......... .......... 42% 2.89M 4s\n",
      "  5550K .......... .......... .......... .......... .......... 42% 1.11M 4s\n",
      "  5600K .......... .......... .......... .......... .......... 43% 2.18M 4s\n",
      "  5650K .......... .......... .......... .......... .......... 43% 1.07M 4s\n",
      "  5700K .......... .......... .......... .......... .......... 43% 2.19M 4s\n",
      "  5750K .......... .......... .......... .......... .......... 44% 2.17M 4s\n",
      "  5800K .......... .......... .......... .......... .......... 44% 2.00M 4s\n",
      "  5850K .......... .......... .......... .......... .......... 45% 2.21M 4s\n",
      "  5900K .......... .......... .......... .......... .......... 45% 2.19M 4s\n",
      "  5950K .......... .......... .......... .......... .......... 45% 1.02M 4s\n",
      "  6000K .......... .......... .......... .......... .......... 46% 1.83M 4s\n",
      "  6050K .......... .......... .......... .......... .......... 46% 2.18M 4s\n",
      "  6100K .......... .......... .......... .......... .......... 47% 2.65M 4s\n",
      "  6150K .......... .......... .......... .......... .......... 47% 2.19M 4s\n",
      "  6200K .......... .......... .......... .......... .......... 47% 2.26M 4s\n",
      "  6250K .......... .......... .......... .......... .......... 48% 2.07M 4s\n",
      "  6300K .......... .......... .......... .......... .......... 48% 2.21M 4s\n",
      "  6350K .......... .......... .......... .......... .......... 48% 1.08M 4s\n",
      "  6400K .......... .......... .......... .......... .......... 49% 1.77M 4s\n",
      "  6450K .......... .......... .......... .......... .......... 49% 1.16M 4s\n",
      "  6500K .......... .......... .......... .......... .......... 50% 1.40M 4s\n",
      "  6550K .......... .......... .......... .......... .......... 50% 1.79M 4s\n",
      "  6600K .......... .......... .......... .......... .......... 50% 1.35M 3s\n",
      "  6650K .......... .......... .......... .......... .......... 51% 1.82M 3s\n",
      "  6700K .......... .......... .......... .......... .......... 51% 3.76M 3s\n",
      "  6750K .......... .......... .......... .......... .......... 51% 1.02M 3s\n",
      "  6800K .......... .......... .......... .......... .......... 52% 1.69M 3s\n",
      "  6850K .......... .......... .......... .......... .......... 52% 2.08M 3s\n",
      "  6900K .......... .......... .......... .......... .......... 53% 1.29M 3s\n",
      "  6950K .......... .......... .......... .......... .......... 53% 2.20M 3s\n",
      "  7000K .......... .......... .......... .......... .......... 53% 1.05M 3s\n",
      "  7050K .......... .......... .......... .......... .......... 54% 2.72M 3s\n",
      "  7100K .......... .......... .......... .......... .......... 54% 2.09M 3s\n",
      "  7150K .......... .......... .......... .......... .......... 55% 1.59M 3s\n",
      "  7200K .......... .......... .......... .......... .......... 55% 1.79M 3s\n",
      "  7250K .......... .......... .......... .......... .......... 55% 3.40M 3s\n",
      "  7300K .......... .......... .......... .......... .......... 56% 1.66M 3s\n",
      "  7350K .......... .......... .......... .......... .......... 56% 2.87M 3s\n",
      "  7400K .......... .......... .......... .......... .......... 56% 1.64M 3s\n",
      "  7450K .......... .......... .......... .......... .......... 57% 1.78M 3s\n",
      "  7500K .......... .......... .......... .......... .......... 57% 1.44M 3s\n",
      "  7550K .......... .......... .......... .......... .......... 58% 1.68M 3s\n",
      "  7600K .......... .......... .......... .......... .......... 58% 2.16M 3s\n",
      "  7650K .......... .......... .......... .......... .......... 58% 1.22M 3s\n",
      "  7700K .......... .......... .......... .......... .......... 59% 1.51M 3s\n",
      "  7750K .......... .......... .......... .......... .......... 59% 3.80M 3s\n",
      "  7800K .......... .......... .......... .......... .......... 60% 2.14M 3s\n",
      "  7850K .......... .......... .......... .......... .......... 60% 1.92M 3s\n",
      "  7900K .......... .......... .......... .......... .......... 60% 1.95M 3s\n",
      "  7950K .......... .......... .......... .......... .......... 61% 1.02M 3s\n",
      "  8000K .......... .......... .......... .......... .......... 61%  686K 3s\n",
      "  8050K .......... .......... .......... .......... .......... 61% 1.66M 3s\n",
      "  8100K .......... .......... .......... .......... .......... 62% 1.33M 3s\n",
      "  8150K .......... .......... .......... .......... .......... 62% 2.47M 3s\n",
      "  8200K .......... .......... .......... .......... .......... 63% 1.67M 3s\n",
      "  8250K .......... .......... .......... .......... .......... 63% 1.83M 3s\n",
      "  8300K .......... .......... .......... .......... .......... 63% 2.64M 3s\n",
      "  8350K .......... .......... .......... .......... .......... 64% 1.87M 3s\n",
      "  8400K .......... .......... .......... .......... .......... 64% 2.22M 3s\n",
      "  8450K .......... .......... .......... .......... .......... 64% 1.32M 3s\n",
      "  8500K .......... .......... .......... .......... .......... 65% 2.02M 2s\n",
      "  8550K .......... .......... .......... .......... .......... 65% 1.73M 2s\n",
      "  8600K .......... .......... .......... .......... .......... 66% 1.87M 2s\n",
      "  8650K .......... .......... .......... .......... .......... 66% 2.23M 2s\n",
      "  8700K .......... .......... .......... .......... .......... 66% 2.15M 2s\n",
      "  8750K .......... .......... .......... .......... .......... 67% 1.09M 2s\n",
      "  8800K .......... .......... .......... .......... .......... 67% 2.00M 2s\n",
      "  8850K .......... .......... .......... .......... .......... 68% 1.81M 2s\n",
      "  8900K .......... .......... .......... .......... .......... 68% 1.70M 2s\n",
      "  8950K .......... .......... .......... .......... .......... 68% 1.31M 2s\n",
      "  9000K .......... .......... .......... .......... .......... 69% 2.41M 2s\n",
      "  9050K .......... .......... .......... .......... .......... 69% 1.69M 2s\n",
      "  9100K .......... .......... .......... .......... .......... 69% 2.73M 2s\n",
      "  9150K .......... .......... .......... .......... .......... 70% 2.18M 2s\n",
      "  9200K .......... .......... .......... .......... .......... 70% 2.05M 2s\n",
      "  9250K .......... .......... .......... .......... .......... 71% 1.46M 2s\n",
      "  9300K .......... .......... .......... .......... .......... 71% 2.35M 2s\n",
      "  9350K .......... .......... .......... .......... .......... 71% 2.21M 2s\n",
      "  9400K .......... .......... .......... .......... .......... 72% 1.74M 2s\n",
      "  9450K .......... .......... .......... .......... .......... 72% 2.12M 2s\n",
      "  9500K .......... .......... .......... .......... .......... 73%  797K 2s\n",
      "  9550K .......... .......... .......... .......... .......... 73% 3.19M 2s\n",
      "  9600K .......... .......... .......... .......... .......... 73% 1.69M 2s\n",
      "  9650K .......... .......... .......... .......... .......... 74% 1.47M 2s\n",
      "  9700K .......... .......... .......... .......... .......... 74% 1.97M 2s\n",
      "  9750K .......... .......... .......... .......... .......... 74% 1.62M 2s\n",
      "  9800K .......... .......... .......... .......... .......... 75% 2.12M 2s\n",
      "  9850K .......... .......... .......... .......... .......... 75% 2.05M 2s\n",
      "  9900K .......... .......... .......... .......... .......... 76% 1.95M 2s\n",
      "  9950K .......... .......... .......... .......... .......... 76% 1.01M 2s\n",
      " 10000K .......... .......... .......... .......... .......... 76% 2.51M 2s\n",
      " 10050K .......... .......... .......... .......... .......... 77% 2.53M 2s\n",
      " 10100K .......... .......... .......... .......... .......... 77% 2.10M 2s\n",
      " 10150K .......... .......... .......... .......... .......... 77% 1.36M 2s\n",
      " 10200K .......... .......... .......... .......... .......... 78% 1.31M 2s\n",
      " 10250K .......... .......... .......... .......... .......... 78% 1.75M 2s\n",
      " 10300K .......... .......... .......... .......... .......... 79% 3.77M 2s\n",
      " 10350K .......... .......... .......... .......... .......... 79% 2.14M 1s\n",
      " 10400K .......... .......... .......... .......... .......... 79% 2.44M 1s\n",
      " 10450K .......... .......... .......... .......... .......... 80% 1.72M 1s\n",
      " 10500K .......... .......... .......... .......... .......... 80% 1.78M 1s\n",
      " 10550K .......... .......... .......... .......... .......... 81% 2.27M 1s\n",
      " 10600K .......... .......... .......... .......... .......... 81% 1.98M 1s\n",
      " 10650K .......... .......... .......... .......... .......... 81% 2.53M 1s\n",
      " 10700K .......... .......... .......... .......... .......... 82% 2.03M 1s\n",
      " 10750K .......... .......... .......... .......... .......... 82% 1.26M 1s\n",
      " 10800K .......... .......... .......... .......... .......... 82% 1.73M 1s\n",
      " 10850K .......... .......... .......... .......... .......... 83% 2.74M 1s\n",
      " 10900K .......... .......... .......... .......... .......... 83% 1.64M 1s\n",
      " 10950K .......... .......... .......... .......... .......... 84% 1.21M 1s\n",
      " 11000K .......... .......... .......... .......... .......... 84% 1.53M 1s\n",
      " 11050K .......... .......... .......... .......... .......... 84% 2.89M 1s\n",
      " 11100K .......... .......... .......... .......... .......... 85%  789K 1s\n",
      " 11150K .......... .......... .......... .......... .......... 85% 1.85M 1s\n",
      " 11200K .......... .......... .......... .......... .......... 86% 1.90M 1s\n",
      " 11250K .......... .......... .......... .......... .......... 86%  936K 1s\n",
      " 11300K .......... .......... .......... .......... .......... 86% 2.06M 1s\n",
      " 11350K .......... .......... .......... .......... .......... 87% 1.95M 1s\n",
      " 11400K .......... .......... .......... .......... .......... 87% 1.97M 1s\n",
      " 11450K .......... .......... .......... .......... .......... 87% 2.29M 1s\n",
      " 11500K .......... .......... .......... .......... .......... 88% 2.09M 1s\n",
      " 11550K .......... .......... .......... .......... .......... 88%  957K 1s\n",
      " 11600K .......... .......... .......... .......... .......... 89% 2.78M 1s\n",
      " 11650K .......... .......... .......... .......... .......... 89% 1.89M 1s\n",
      " 11700K .......... .......... .......... .......... .......... 89% 1.96M 1s\n",
      " 11750K .......... .......... .......... .......... .......... 90% 1.71M 1s\n",
      " 11800K .......... .......... .......... .......... .......... 90% 1.93M 1s\n",
      " 11850K .......... .......... .......... .......... .......... 90%  523K 1s\n",
      " 11900K .......... .......... .......... .......... .......... 91% 7.70M 1s\n",
      " 11950K .......... .......... .......... .......... .......... 91% 42.7M 1s\n",
      " 12000K .......... .......... .......... .......... .......... 92% 4.57M 1s\n",
      " 12050K .......... .......... .......... .......... .......... 92% 1.63M 1s\n",
      " 12100K .......... .......... .......... .......... .......... 92% 2.15M 1s\n",
      " 12150K .......... .......... .......... .......... .......... 93% 1.64M 0s\n",
      " 12200K .......... .......... .......... .......... .......... 93% 1.49M 0s\n",
      " 12250K .......... .......... .......... .......... .......... 94% 2.46M 0s\n",
      " 12300K .......... .......... .......... .......... .......... 94% 1.54M 0s\n",
      " 12350K .......... .......... .......... .......... .......... 94% 1.24M 0s\n",
      " 12400K .......... .......... .......... .......... .......... 95% 2.09M 0s\n",
      " 12450K .......... .......... .......... .......... .......... 95% 1.63M 0s\n",
      " 12500K .......... .......... .......... .......... .......... 95% 2.97M 0s\n",
      " 12550K .......... .......... .......... .......... .......... 96% 2.20M 0s\n",
      " 12600K .......... .......... .......... .......... .......... 96% 1.58M 0s\n",
      " 12650K .......... .......... .......... .......... .......... 97% 2.16M 0s\n",
      " 12700K .......... .......... .......... .......... .......... 97% 2.64M 0s\n",
      " 12750K .......... .......... .......... .......... .......... 97% 1.12M 0s\n",
      " 12800K .......... .......... .......... .......... .......... 98% 2.23M 0s\n",
      " 12850K .......... .......... .......... .......... .......... 98% 2.04M 0s\n",
      " 12900K .......... .......... .......... .......... .......... 99% 1.94M 0s\n",
      " 12950K .......... .......... .......... .......... .......... 99% 2.25M 0s\n",
      " 13000K .......... .......... .......... .......... .......... 99% 2.04M 0s\n",
      " 13050K .......... .......... ........                        100% 1.63M=7.2s\n",
      "\n",
      "2022-01-13 11:50:58 (1.77 MB/s) - 'unsw_nb15_binarized.npz' saved [13391907/13391907]\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Samples in each set: train = 175341, test = 82332\n",
      "Shape of one input sample: torch.Size([593])\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "def get_preqnt_dataset(data_dir: str, train: bool):\n",
    "    unsw_nb15_data = np.load(data_dir + \"/unsw_nb15_binarized.npz\")\n",
    "    if train:\n",
    "        partition = \"train\"\n",
    "    else:\n",
    "        partition = \"test\"\n",
    "    part_data = unsw_nb15_data[partition].astype(np.float32)\n",
    "    part_data = torch.from_numpy(part_data)\n",
    "    part_data_in = part_data[:, :-1]\n",
    "    part_data_out = part_data[:, -1]\n",
    "    return TensorDataset(part_data_in, part_data_out)\n",
    "\n",
    "train_quantized_dataset = get_preqnt_dataset(\".\", True)\n",
    "test_quantized_dataset = get_preqnt_dataset(\".\", False)\n",
    "\n",
    "print(\"Samples in each set: train = %d, test = %s\" % (len(train_quantized_dataset), len(test_quantized_dataset))) \n",
    "print(\"Shape of one input sample: \" +  str(train_quantized_dataset[0][0].shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up DataLoader\n",
    "\n",
    "Following either option, we now have access to the quantized dataset. We will wrap the dataset in a PyTorch `DataLoader` for easier access in batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "batch_size = 1000\n",
    "\n",
    "# dataset loaders\n",
    "train_quantized_loader = DataLoader(train_quantized_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_quantized_loader = DataLoader(test_quantized_dataset, batch_size=batch_size, shuffle=False)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape for 1 batch: torch.Size([1000, 593])\n",
      "Label shape for 1 batch: torch.Size([1000])\n"
     ]
    }
   ],
   "source": [
    "# fetch a single batch to inspect its shape\n",
    "x, y = next(iter(train_quantized_loader))\n",
    "print(\"Input shape for 1 batch: \" + str(x.shape))\n",
    "print(\"Label shape for 1 batch: \" + str(y.shape))"
   ]
  },
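  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of how a `DataLoader` splits a dataset into batches: the number of batches per epoch is `ceil(n_samples / batch_size)`, and the last batch may be smaller than `batch_size`. The toy dataset and its size of 2500 samples are made up for illustration and are not the actual UNSW-NB15 partition sizes.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# hypothetical toy dataset: sizes chosen for illustration only\n",
    "n_samples, n_features, toy_batch_size = 2500, 593, 1000\n",
    "toy_dataset = TensorDataset(torch.zeros(n_samples, n_features), torch.zeros(n_samples))\n",
    "toy_loader = DataLoader(toy_dataset, batch_size=toy_batch_size, shuffle=True)\n",
    "\n",
    "# len() of a DataLoader is the number of batches per epoch\n",
    "assert len(toy_loader) == math.ceil(n_samples / toy_batch_size)\n",
    "\n",
    "# the last batch holds the remaining 500 samples\n",
    "print([x.shape[0] for x, _ in toy_loader])  # [1000, 1000, 500]\n",
    "```\n",
    "\n",
    "If a fixed batch size is required, `DataLoader` also accepts `drop_last=True`, which discards the incomplete final batch."
   ]
  },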
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define a PyTorch Device <a id='define_pytorch_device'></a> \n",
    "\n",
    "GPUs can significantly speed up the training of deep neural networks. We check whether a GPU is available and, if so, use it as the target device; otherwise we fall back to the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target device: cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Target device: \" + str(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the Quantized MLP Model <a id='define_quantized_mlp'></a>\n",
    "\n",
    "We'll now define an MLP model that will be trained to perform inference with quantized weights and activations.\n",
    "For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).\n",
    "\n",
    "Our MLP will have four fully-connected (FC) layers in total: three hidden layers with 64 neurons each, and a final output layer with a single output, all using 2-bit weights. Each hidden FC layer is followed by batch normalization, dropout (p=0.5) and a 2-bit quantized ReLU activation, in that order.\n",
    "\n",
    "In case you'd like to experiment with different quantization settings or topology parameters, we'll define all these topology settings as variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_size = 593      \n",
    "hidden1 = 64      \n",
    "hidden2 = 64\n",
    "hidden3 = 64\n",
    "weight_bit_width = 2\n",
    "act_bit_width = 2\n",
    "num_classes = 1    "
   ]
  },
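  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how small this topology is, the weight and bias counts can be worked out directly from the settings above. This is only a back-of-the-envelope sketch: it ignores batch-norm parameters and quantization scale factors, and the helper names are made up for illustration.\n",
    "\n",
    "```python\n",
    "# topology values copied from the cell above\n",
    "input_size, hidden1, hidden2, hidden3, num_classes = 593, 64, 64, 64, 1\n",
    "weight_bit_width = 2\n",
    "\n",
    "# (inputs, outputs) of each fully-connected layer\n",
    "layer_shapes = [(input_size, hidden1), (hidden1, hidden2), (hidden2, hidden3), (hidden3, num_classes)]\n",
    "n_weights = sum(n_in * n_out for n_in, n_out in layer_shapes)\n",
    "n_biases = sum(n_out for _, n_out in layer_shapes)\n",
    "\n",
    "print(n_weights)  # 46208\n",
    "print(n_biases)   # 193\n",
    "\n",
    "# storage for the 2-bit weights alone, in bytes\n",
    "print(n_weights * weight_bit_width / 8)  # 11552.0\n",
    "```\n",
    "\n",
    "For comparison, the same weights stored as 32-bit floats would take 16 times as much space."
   ]
  },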
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can define our MLP using the layer primitives provided by Brevitas:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): QuantLinear(\n",
       "    in_features=593, out_features=64, bias=True\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (output_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (weight_quant): WeightQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (tensor_quant): RescalingIntQuant(\n",
       "        (int_quant): IntQuant(\n",
       "          (float_to_int_impl): RoundSte()\n",
       "          (tensor_clamp_impl): TensorClampSte()\n",
       "          (delay_wrapper): DelayWrapper(\n",
       "            (delay_impl): _NoDelay()\n",
       "          )\n",
       "        )\n",
       "        (scaling_impl): StatsFromParameterScaling(\n",
       "          (parameter_list_stats): _ParameterListStats(\n",
       "            (first_tracked_param): _ViewParameterWrapper(\n",
       "              (view_shape_impl): OverTensorView()\n",
       "            )\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsMax()\n",
       "            )\n",
       "          )\n",
       "          (stats_scaling_impl): _StatsScaling(\n",
       "            (affine_rescaling): Identity()\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_scaling_pre): Identity()\n",
       "          )\n",
       "        )\n",
       "        (int_scaling_impl): IntScaling()\n",
       "        (zero_point_impl): ZeroZeroPoint(\n",
       "          (zero_point): StatelessBuffer()\n",
       "        )\n",
       "        (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "          (bit_width): StatelessBuffer()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (bias_quant): BiasQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "  )\n",
       "  (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (2): Dropout(p=0.5, inplace=False)\n",
       "  (3): QuantReLU(\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (act_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "        (activation_impl): ReLU()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClamp()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "            (stats_input_view_shape_impl): OverTensorView()\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsPercentile()\n",
       "            )\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_inplace_preprocess): Identity()\n",
       "            (restrict_preprocess): Identity()\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (4): QuantLinear(\n",
       "    in_features=64, out_features=64, bias=True\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (output_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (weight_quant): WeightQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (tensor_quant): RescalingIntQuant(\n",
       "        (int_quant): IntQuant(\n",
       "          (float_to_int_impl): RoundSte()\n",
       "          (tensor_clamp_impl): TensorClampSte()\n",
       "          (delay_wrapper): DelayWrapper(\n",
       "            (delay_impl): _NoDelay()\n",
       "          )\n",
       "        )\n",
       "        (scaling_impl): StatsFromParameterScaling(\n",
       "          (parameter_list_stats): _ParameterListStats(\n",
       "            (first_tracked_param): _ViewParameterWrapper(\n",
       "              (view_shape_impl): OverTensorView()\n",
       "            )\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsMax()\n",
       "            )\n",
       "          )\n",
       "          (stats_scaling_impl): _StatsScaling(\n",
       "            (affine_rescaling): Identity()\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_scaling_pre): Identity()\n",
       "          )\n",
       "        )\n",
       "        (int_scaling_impl): IntScaling()\n",
       "        (zero_point_impl): ZeroZeroPoint(\n",
       "          (zero_point): StatelessBuffer()\n",
       "        )\n",
       "        (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "          (bit_width): StatelessBuffer()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (bias_quant): BiasQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "  )\n",
       "  (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (6): Dropout(p=0.5, inplace=False)\n",
       "  (7): QuantReLU(\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (act_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "        (activation_impl): ReLU()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClamp()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "            (stats_input_view_shape_impl): OverTensorView()\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsPercentile()\n",
       "            )\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_inplace_preprocess): Identity()\n",
       "            (restrict_preprocess): Identity()\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (8): QuantLinear(\n",
       "    in_features=64, out_features=64, bias=True\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (output_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (weight_quant): WeightQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (tensor_quant): RescalingIntQuant(\n",
       "        (int_quant): IntQuant(\n",
       "          (float_to_int_impl): RoundSte()\n",
       "          (tensor_clamp_impl): TensorClampSte()\n",
       "          (delay_wrapper): DelayWrapper(\n",
       "            (delay_impl): _NoDelay()\n",
       "          )\n",
       "        )\n",
       "        (scaling_impl): StatsFromParameterScaling(\n",
       "          (parameter_list_stats): _ParameterListStats(\n",
       "            (first_tracked_param): _ViewParameterWrapper(\n",
       "              (view_shape_impl): OverTensorView()\n",
       "            )\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsMax()\n",
       "            )\n",
       "          )\n",
       "          (stats_scaling_impl): _StatsScaling(\n",
       "            (affine_rescaling): Identity()\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_scaling_pre): Identity()\n",
       "          )\n",
       "        )\n",
       "        (int_scaling_impl): IntScaling()\n",
       "        (zero_point_impl): ZeroZeroPoint(\n",
       "          (zero_point): StatelessBuffer()\n",
       "        )\n",
       "        (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "          (bit_width): StatelessBuffer()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (bias_quant): BiasQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "  )\n",
       "  (9): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (10): Dropout(p=0.5, inplace=False)\n",
       "  (11): QuantReLU(\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (act_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "        (activation_impl): ReLU()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClamp()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "            (stats_input_view_shape_impl): OverTensorView()\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsPercentile()\n",
       "            )\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_inplace_preprocess): Identity()\n",
       "            (restrict_preprocess): Identity()\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (12): QuantLinear(\n",
       "    in_features=64, out_features=1, bias=True\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (output_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (weight_quant): WeightQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (tensor_quant): RescalingIntQuant(\n",
       "        (int_quant): IntQuant(\n",
       "          (float_to_int_impl): RoundSte()\n",
       "          (tensor_clamp_impl): TensorClampSte()\n",
       "          (delay_wrapper): DelayWrapper(\n",
       "            (delay_impl): _NoDelay()\n",
       "          )\n",
       "        )\n",
       "        (scaling_impl): StatsFromParameterScaling(\n",
       "          (parameter_list_stats): _ParameterListStats(\n",
       "            (first_tracked_param): _ViewParameterWrapper(\n",
       "              (view_shape_impl): OverTensorView()\n",
       "            )\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsMax()\n",
       "            )\n",
       "          )\n",
       "          (stats_scaling_impl): _StatsScaling(\n",
       "            (affine_rescaling): Identity()\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_scaling_pre): Identity()\n",
       "          )\n",
       "        )\n",
       "        (int_scaling_impl): IntScaling()\n",
       "        (zero_point_impl): ZeroZeroPoint(\n",
       "          (zero_point): StatelessBuffer()\n",
       "        )\n",
       "        (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "          (bit_width): StatelessBuffer()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (bias_quant): BiasQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from brevitas.nn import QuantLinear, QuantReLU\n",
    "import torch.nn as nn\n",
    "\n",
    "# Setting seeds for reproducibility\n",
    "torch.manual_seed(0)\n",
    "\n",
    "model = nn.Sequential(\n",
    "      QuantLinear(input_size, hidden1, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden1),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden1, hidden2, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden2),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden2, hidden3, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden3),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden3, num_classes, bias=True, weight_bit_width=weight_bit_width)\n",
    ")\n",
    "\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the MLP's output is not yet quantized. Even though we want the final output of our MLP to be a binary (0/1) value indicating the classification, we've only defined a single-neuron FC layer as the output. While training the network we'll pass that output through a sigmoid function as part of the loss criterion, which [gives better numerical stability](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html). Later on, after we're done training the network, we'll add a quantization node at the end before we export it to FINN."
   ]
  },
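  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between the raw output, the sigmoid and the loss can be checked with a small sketch. The random tensors below are stand-ins for the MLP output and the labels, not real data.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(8, 1)                    # stand-in for the raw MLP output\n",
    "target = torch.randint(0, 2, (8, 1)).float()  # stand-in 0/1 labels\n",
    "\n",
    "# BCEWithLogitsLoss fuses the sigmoid and the binary cross-entropy into one op\n",
    "loss_fused = nn.BCEWithLogitsLoss()(logits, target)\n",
    "loss_split = nn.BCELoss()(torch.sigmoid(logits), target)\n",
    "assert torch.allclose(loss_fused, loss_split, atol=1e-6)\n",
    "\n",
    "# sigmoid(z) > 0.5 holds when z > 0, so thresholding the raw output at 0\n",
    "# yields the same 0/1 decision as thresholding the sigmoid at 0.5\n",
    "assert torch.equal(torch.sigmoid(logits) > 0.5, logits > 0)\n",
    "```\n",
    "\n",
    "The second check is why a simple sign-based quantizer on the raw output can stand in for the sigmoid at inference time."
   ]
  },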
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Train and Test Methods <a id='train_test'></a>\n",
    "The train and test methods each take a `DataLoader`, which feeds the model one batch of data per iteration until the whole dataset has been consumed. One full pass over the training data is called an *epoch*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, optimizer, criterion):\n",
    "    losses = []\n",
    "    # ensure model is in training mode\n",
    "    model.train()    \n",
    "    \n",
    "    for i, data in enumerate(train_loader, 0):        \n",
    "        inputs, target = data\n",
    "        inputs, target = inputs.to(device), target.to(device)\n",
    "        optimizer.zero_grad()   \n",
    "                \n",
    "        # forward pass\n",
    "        output = model(inputs.float())\n",
    "        loss = criterion(output, target.unsqueeze(1))\n",
    "        \n",
    "        # backward pass + run optimizer to update weights\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # keep track of loss value\n",
    "        losses.append(loss.item())\n",
    "           \n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "def test(model, test_loader):    \n",
    "    # ensure model is in eval mode\n",
    "    model.eval() \n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "   \n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            inputs, target = data\n",
    "            inputs, target = inputs.to(device), target.to(device)\n",
    "            output_orig = model(inputs.float())\n",
    "            # run the output through sigmoid\n",
    "            output = torch.sigmoid(output_orig)  \n",
    "            # compare against a threshold of 0.5 to generate 0/1\n",
    "            pred = (output.detach().cpu().numpy() > 0.5) * 1\n",
    "            target = target.cpu().float()\n",
    "            y_true.extend(target.tolist()) \n",
    "            y_pred.extend(pred.reshape(-1).tolist())\n",
    "        \n",
    "    return accuracy_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the QNN <a id=\"train_qnn\"></a>\n",
    "\n",
    "We provide two options for training below: you can opt for training the model from scratch (slower) or use a pre-trained model (faster). The first option will give more insight into how the training process works, while the second option will likely give better accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Option 1, slower) Train the Model from Scratch <a id=\"train_scratch\"></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we start training our MLP we need to define some hyperparameters. To monitor how the loss evolves over the epochs, we also define a small plotting helper. As mentioned earlier, we'll use a loss criterion which applies a sigmoid function during the training phase (`BCEWithLogitsLoss`). For the testing phase, we manually compute the sigmoid and threshold it at 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "lr = 0.001 \n",
    "\n",
    "def display_loss_plot(losses, title=\"Training loss\", xlabel=\"Epochs\", ylabel=\"Loss\"):\n",
    "    x_axis = list(range(len(losses)))\n",
    "    plt.plot(x_axis,losses)\n",
    "    plt.title(title)\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss criterion and optimizer\n",
    "criterion = nn.BCEWithLogitsLoss().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss = 0.132592 test accuracy = 0.808884: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:06<00:00, 12.67s/it]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "# Setting seeds for reproducibility\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "running_loss = []\n",
    "running_test_acc = []\n",
    "t = trange(num_epochs, desc=\"Training loss\", leave=True)\n",
    "\n",
    "for epoch in t:\n",
    "    loss_epoch = train(model, train_quantized_loader, optimizer, criterion)\n",
    "    test_acc = test(model, test_quantized_loader)\n",
    "    t.set_description(\"Training loss = %f test accuracy = %f\" % (np.mean(loss_epoch), test_acc))\n",
    "    t.refresh()  # show the update immediately\n",
    "    running_loss.append(loss_epoch)\n",
    "    running_test_acc.append(test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "loss_per_epoch = [np.mean(epoch_losses) for epoch_losses in running_loss]\n",
    "display_loss_plot(loss_per_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# test() returns a fraction in [0, 1]; convert to percent to match the axis label\n",
    "acc_per_epoch = [100.0 * acc for acc in running_test_acc]\n",
    "display_loss_plot(acc_per_epoch, title=\"Test accuracy\", xlabel=\"Epochs\", ylabel=\"Accuracy [%]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8088835446727882"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(model, test_quantized_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the Brevitas model to disk\n",
    "torch.save(model.state_dict(), \"state_dict_self-trained.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Option 2, faster) Load Pre-Trained Parameters <a id=\"load_pretrained\"></a>\n",
    "\n",
    "Instead of training from scratch, you can also use pre-trained parameters we provide here. These parameters should achieve ~91.9% test accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Make sure the model is on CPU before loading a pretrained state_dict\n",
    "model = model.cpu()\n",
    "\n",
    "# Load pretrained weights\n",
    "trained_state_dict = torch.load(\"state_dict.pth\", map_location=\"cpu\")[\"models_state_dict\"][0]\n",
    "\n",
    "model.load_state_dict(trained_state_dict, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9188772287810328"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Move the model back to its target device\n",
    "model.to(device)\n",
    "\n",
    "# Test for accuracy\n",
    "test(model, test_quantized_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Why do these parameters give better accuracy than training from scratch?** Even with the topology and quantization fixed, achieving good accuracy on a given dataset requires [*hyperparameter tuning*](https://towardsdatascience.com/hyperparameters-optimization-526348bb8e2d) and potentially running training for a long time. The \"training from scratch\" example above is only intended as a quick example, whereas the pretrained parameters are obtained from a longer training run using the [determined.ai](https://determined.ai/) platform for hyperparameter tuning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network Surgery Before Export <a id=\"network_surgery\"></a>\n",
    "\n",
    "Sometimes, it's desirable to make some changes to our trained network prior to export (this is known in general as \"network surgery\"). This depends on the model and is not generally necessary, but in this case we want to make a couple of changes to get better results with FINN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to CPU before surgery\n",
    "model = model.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by padding the input. Our input vectors are 593-bit, which makes folding (parallelization) for the first layer a bit tricky: 593 is a prime number, so it has no divisors other than 1 and itself. We'll therefore pad the weight matrix of the first layer with seven 0-valued columns to work with an input size of 600 instead. When using the modified network we'll similarly provide inputs padded to 600 bits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(64, 593)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "modified_model = deepcopy(model)\n",
    "\n",
    "W_orig = modified_model[0].weight.data.detach().numpy()\n",
    "W_orig.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(64, 600)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# pad the second (593-sized) dimension with 7 zeroes at the end\n",
    "W_new = np.pad(W_orig, [(0,0), (0,7)])\n",
    "W_new.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 600])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modified_model[0].weight.data = torch.from_numpy(W_new)\n",
    "modified_model[0].weight.shape"
   ]
  },
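  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of the padding can be sanity-checked with a small sketch. It uses a plain floating-point `nn.Linear` as a stand-in for the first `QuantLinear` layer (an assumption made to keep the example self-contained), and random 0/1 vectors instead of real samples. Since the padded weight columns are zero, the extra seven input positions cannot influence the output.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# 593 is prime, while 600 offers many possible folding factors\n",
    "print([d for d in range(1, 594) if 593 % d == 0])  # [1, 593]\n",
    "print(len([d for d in range(1, 601) if 600 % d == 0]))  # 24\n",
    "\n",
    "torch.manual_seed(0)\n",
    "fc = nn.Linear(593, 64)                    # float stand-in for the first layer\n",
    "x = torch.randint(0, 2, (4, 593)).float()  # stand-in binary inputs\n",
    "\n",
    "# build a 600-input layer whose last 7 weight columns are zero\n",
    "fc_padded = nn.Linear(600, 64)\n",
    "with torch.no_grad():\n",
    "    W_padded = np.pad(fc.weight.detach().numpy(), [(0, 0), (0, 7)])\n",
    "    fc_padded.weight.copy_(torch.from_numpy(W_padded))\n",
    "    fc_padded.bias.copy_(fc.bias)\n",
    "\n",
    "# append 7 zeros to every input vector and compare the outputs\n",
    "x_padded = F.pad(x, (0, 7))\n",
    "assert torch.allclose(fc(x), fc_padded(x_padded), atol=1e-5)\n",
    "```\n",
    "\n",
    "For the quantized layer, the added zeros should also leave a per-tensor absolute-maximum weight scale unchanged, since zero never exceeds the existing maximum magnitude."
   ]
  },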
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll modify the expected input/output ranges. In FINN, we prefer to work with bipolar {-1, +1} instead of binary {0, 1} values. To achieve this, we'll create a \"wrapper\" model that handles the pre/postprocessing as follows:\n",
    "\n",
    "* on the input side, we'll pre-process by (x + 1) / 2 in order to map incoming {-1, +1} inputs to the {0, 1} values the trained network expects. Since we're just multiplying/adding a scalar, these operations can be [*streamlined*](https://finn.readthedocs.io/en/latest/nw_prep.html#streamlining-transformations) by FINN and implemented with no extra cost.\n",
    "\n",
    "* on the output side, we'll add a binary quantizer which maps everything below 0 to -1 and everything above 0 to +1. This is essentially the same decision as applying the sigmoid and thresholding at 0.5 as we did earlier, except the outputs are bipolar instead of binary."
   ]
  },
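  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two mappings described above can be illustrated with plain PyTorch tensors. This is a minimal sketch, not the wrapper model itself: the example values are made up, and mapping an output of exactly 0 to +1 is an assumption, since the description leaves that case open.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# input side: bipolar {-1, +1} to binary {0, 1}\n",
    "x_bipolar = torch.tensor([-1.0, 1.0, 1.0, -1.0])\n",
    "x_binary = (x_bipolar + 1) / 2\n",
    "print(x_binary)  # tensor([0., 1., 1., 0.])\n",
    "\n",
    "# output side: made-up raw outputs of the MLP\n",
    "logits = torch.tensor([-2.3, 0.7, 4.1, -0.2])\n",
    "\n",
    "# binary decision as computed during testing: sigmoid, then threshold at 0.5\n",
    "binary_pred = (torch.sigmoid(logits) > 0.5).float()\n",
    "\n",
    "# bipolar decision: sign of the raw output (0 mapped to +1 by assumption)\n",
    "bipolar_pred = torch.where(logits >= 0, torch.ones_like(logits), -torch.ones_like(logits))\n",
    "\n",
    "# both decisions agree once mapped to the same value range\n",
    "assert torch.equal((bipolar_pred + 1) / 2, binary_pred)\n",
    "```\n",
    "\n",
    "The same (x + 1) / 2 relation therefore links the inputs and the predictions of the binary and bipolar views of the network."
   ]
  },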
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CybSecMLPForExport(\n",
       "  (pretrained): Sequential(\n",
       "    (0): QuantLinear(\n",
       "      in_features=593, out_features=64, bias=True\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (output_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (weight_quant): WeightQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClampSte()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): StatsFromParameterScaling(\n",
       "            (parameter_list_stats): _ParameterListStats(\n",
       "              (first_tracked_param): _ViewParameterWrapper(\n",
       "                (view_shape_impl): OverTensorView()\n",
       "              )\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsMax()\n",
       "              )\n",
       "            )\n",
       "            (stats_scaling_impl): _StatsScaling(\n",
       "              (affine_rescaling): Identity()\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_scaling_pre): Identity()\n",
       "            )\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (bias_quant): BiasQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "    )\n",
       "    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): Dropout(p=0.5, inplace=False)\n",
       "    (3): QuantReLU(\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (act_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "          (activation_impl): ReLU()\n",
       "          (tensor_quant): RescalingIntQuant(\n",
       "            (int_quant): IntQuant(\n",
       "              (float_to_int_impl): RoundSte()\n",
       "              (tensor_clamp_impl): TensorClamp()\n",
       "              (delay_wrapper): DelayWrapper(\n",
       "                (delay_impl): _NoDelay()\n",
       "              )\n",
       "            )\n",
       "            (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "              (stats_input_view_shape_impl): OverTensorView()\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsPercentile()\n",
       "              )\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_inplace_preprocess): Identity()\n",
       "              (restrict_preprocess): Identity()\n",
       "            )\n",
       "            (int_scaling_impl): IntScaling()\n",
       "            (zero_point_impl): ZeroZeroPoint(\n",
       "              (zero_point): StatelessBuffer()\n",
       "            )\n",
       "            (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "              (bit_width): StatelessBuffer()\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (4): QuantLinear(\n",
       "      in_features=64, out_features=64, bias=True\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (output_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (weight_quant): WeightQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClampSte()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): StatsFromParameterScaling(\n",
       "            (parameter_list_stats): _ParameterListStats(\n",
       "              (first_tracked_param): _ViewParameterWrapper(\n",
       "                (view_shape_impl): OverTensorView()\n",
       "              )\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsMax()\n",
       "              )\n",
       "            )\n",
       "            (stats_scaling_impl): _StatsScaling(\n",
       "              (affine_rescaling): Identity()\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_scaling_pre): Identity()\n",
       "            )\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (bias_quant): BiasQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "    )\n",
       "    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (6): Dropout(p=0.5, inplace=False)\n",
       "    (7): QuantReLU(\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (act_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "          (activation_impl): ReLU()\n",
       "          (tensor_quant): RescalingIntQuant(\n",
       "            (int_quant): IntQuant(\n",
       "              (float_to_int_impl): RoundSte()\n",
       "              (tensor_clamp_impl): TensorClamp()\n",
       "              (delay_wrapper): DelayWrapper(\n",
       "                (delay_impl): _NoDelay()\n",
       "              )\n",
       "            )\n",
       "            (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "              (stats_input_view_shape_impl): OverTensorView()\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsPercentile()\n",
       "              )\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_inplace_preprocess): Identity()\n",
       "              (restrict_preprocess): Identity()\n",
       "            )\n",
       "            (int_scaling_impl): IntScaling()\n",
       "            (zero_point_impl): ZeroZeroPoint(\n",
       "              (zero_point): StatelessBuffer()\n",
       "            )\n",
       "            (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "              (bit_width): StatelessBuffer()\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (8): QuantLinear(\n",
       "      in_features=64, out_features=64, bias=True\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (output_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (weight_quant): WeightQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClampSte()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): StatsFromParameterScaling(\n",
       "            (parameter_list_stats): _ParameterListStats(\n",
       "              (first_tracked_param): _ViewParameterWrapper(\n",
       "                (view_shape_impl): OverTensorView()\n",
       "              )\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsMax()\n",
       "              )\n",
       "            )\n",
       "            (stats_scaling_impl): _StatsScaling(\n",
       "              (affine_rescaling): Identity()\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_scaling_pre): Identity()\n",
       "            )\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (bias_quant): BiasQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "    )\n",
       "    (9): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (10): Dropout(p=0.5, inplace=False)\n",
       "    (11): QuantReLU(\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (act_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "          (activation_impl): ReLU()\n",
       "          (tensor_quant): RescalingIntQuant(\n",
       "            (int_quant): IntQuant(\n",
       "              (float_to_int_impl): RoundSte()\n",
       "              (tensor_clamp_impl): TensorClamp()\n",
       "              (delay_wrapper): DelayWrapper(\n",
       "                (delay_impl): _NoDelay()\n",
       "              )\n",
       "            )\n",
       "            (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "              (stats_input_view_shape_impl): OverTensorView()\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsPercentile()\n",
       "              )\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_inplace_preprocess): Identity()\n",
       "              (restrict_preprocess): Identity()\n",
       "            )\n",
       "            (int_scaling_impl): IntScaling()\n",
       "            (zero_point_impl): ZeroZeroPoint(\n",
       "              (zero_point): StatelessBuffer()\n",
       "            )\n",
       "            (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "              (bit_width): StatelessBuffer()\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (12): QuantLinear(\n",
       "      in_features=64, out_features=1, bias=True\n",
       "      (input_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (output_quant): ActQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "      (weight_quant): WeightQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "        (tensor_quant): RescalingIntQuant(\n",
       "          (int_quant): IntQuant(\n",
       "            (float_to_int_impl): RoundSte()\n",
       "            (tensor_clamp_impl): TensorClampSte()\n",
       "            (delay_wrapper): DelayWrapper(\n",
       "              (delay_impl): _NoDelay()\n",
       "            )\n",
       "          )\n",
       "          (scaling_impl): StatsFromParameterScaling(\n",
       "            (parameter_list_stats): _ParameterListStats(\n",
       "              (first_tracked_param): _ViewParameterWrapper(\n",
       "                (view_shape_impl): OverTensorView()\n",
       "              )\n",
       "              (stats): _Stats(\n",
       "                (stats_impl): AbsMax()\n",
       "              )\n",
       "            )\n",
       "            (stats_scaling_impl): _StatsScaling(\n",
       "              (affine_rescaling): Identity()\n",
       "              (restrict_clamp_scaling): _RestrictClampValue(\n",
       "                (clamp_min_ste): ScalarClampMinSte()\n",
       "                (restrict_value_impl): FloatRestrictValue()\n",
       "              )\n",
       "              (restrict_scaling_pre): Identity()\n",
       "            )\n",
       "          )\n",
       "          (int_scaling_impl): IntScaling()\n",
       "          (zero_point_impl): ZeroZeroPoint(\n",
       "            (zero_point): StatelessBuffer()\n",
       "          )\n",
       "          (msb_clamp_bit_width_impl): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (bias_quant): BiasQuantProxyFromInjector(\n",
       "        (_zero_hw_sentinel): StatelessBuffer()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (qnt_output): QuantIdentity(\n",
       "    (input_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "    )\n",
       "    (act_quant): ActQuantProxyFromInjector(\n",
       "      (_zero_hw_sentinel): StatelessBuffer()\n",
       "      (fused_activation_quant_proxy): FusedActivationQuantProxy(\n",
       "        (activation_impl): Identity()\n",
       "        (tensor_quant): ClampedBinaryQuant(\n",
       "          (scaling_impl): ParameterFromRuntimeStatsScaling(\n",
       "            (stats_input_view_shape_impl): OverTensorView()\n",
       "            (stats): _Stats(\n",
       "              (stats_impl): AbsPercentile()\n",
       "            )\n",
       "            (restrict_clamp_scaling): _RestrictClampValue(\n",
       "              (clamp_min_ste): ScalarClampMinSte()\n",
       "              (restrict_value_impl): FloatRestrictValue()\n",
       "            )\n",
       "            (restrict_inplace_preprocess): Identity()\n",
       "            (restrict_preprocess): Identity()\n",
       "          )\n",
       "          (bit_width): BitWidthConst(\n",
       "            (bit_width): StatelessBuffer()\n",
       "          )\n",
       "          (zero_point): StatelessBuffer()\n",
       "          (delay_wrapper): DelayWrapper(\n",
       "            (delay_impl): _NoDelay()\n",
       "          )\n",
       "          (tensor_clamp_impl): TensorClamp()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from brevitas.core.quant import QuantType\n",
    "from brevitas.nn import QuantIdentity\n",
    "\n",
    "\n",
    "class CybSecMLPForExport(nn.Module):\n",
    "    def __init__(self, my_pretrained_model):\n",
    "        super(CybSecMLPForExport, self).__init__()\n",
    "        self.pretrained = my_pretrained_model\n",
    "        self.qnt_output = QuantIdentity(quant_type=QuantType.BINARY, bit_width=1, min_val=-1.0, max_val=1.0)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        # assume x contains bipolar {-1,1} elems\n",
    "        # shift from {-1,1} -> {0,1} since that is the\n",
    "        # input range for the trained network\n",
    "        x = (x + torch.tensor([1.0]).to(x.device)) / 2.0  \n",
    "        out_original = self.pretrained(x)\n",
    "        out_final = self.qnt_output(out_original)   # output as {-1,1}     \n",
    "        return out_final\n",
    "\n",
    "model_for_export = CybSecMLPForExport(modified_model)\n",
    "model_for_export.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_padded_bipolar(model, test_loader):    \n",
    "    # ensure model is in eval mode\n",
    "    model.eval() \n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "   \n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            inputs, target = data\n",
    "            inputs, target = inputs.to(device), target.to(device)\n",
    "            # pad inputs to 600 elements\n",
    "            input_padded = torch.nn.functional.pad(inputs, (0,7,0,0))\n",
    "            # convert inputs to {-1,+1}\n",
    "            input_scaled = 2 * input_padded - 1\n",
    "            # run the model\n",
    "            output = model(input_scaled.float())\n",
    "            y_pred.extend(list(output.flatten().cpu().numpy()))\n",
    "            # make targets bipolar {-1,+1}\n",
    "            expected = 2 * target.float() - 1\n",
    "            expected = expected.cpu().numpy()\n",
    "            y_true.extend(list(expected.flatten()))\n",
    "        \n",
    "    return accuracy_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9188772287810328"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_padded_bipolar(model_for_export, test_quantized_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export to FINN-ONNX <a id=\"export_finn_onnx\" ></a>\n",
    "\n",
    "\n",
    "[ONNX](https://onnx.ai/) is an open format built to represent machine learning models, and the FINN compiler expects an ONNX model as input. We'll now export our network into ONNX to be imported and used in FINN for the next notebooks. Note that the particular ONNX representation used for FINN differs from standard ONNX; you can read more about this [here](https://finn.readthedocs.io/en/latest/internals.html#intermediate-representation-finn-onnx).\n",
    "\n",
    "You can see below how we export a trained network in Brevitas into a FINN-compatible ONNX representation. Note how we create a `QuantTensor` instance with dummy data to tell Brevitas what our inputs look like, which will be used to set the input quantization annotation on the exported model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model saved to cybsec-mlp-ready.onnx\n"
     ]
    }
   ],
   "source": [
    "import brevitas.onnx as bo\n",
    "from brevitas.quant_tensor import QuantTensor\n",
    "\n",
    "ready_model_filename = \"cybsec-mlp-ready.onnx\"\n",
    "input_shape = (1, 600)\n",
    "\n",
    "# create a QuantTensor instance to mark input as bipolar during export\n",
    "# randint's upper bound is exclusive, so use 2 to draw dummy values from {0, 1}\n",
    "input_a = np.random.randint(0, 2, size=input_shape).astype(np.float32)\n",
    "# map {0, 1} -> {-1, +1}\n",
    "input_a = 2 * input_a - 1\n",
    "scale = 1.0\n",
    "input_t = torch.from_numpy(input_a * scale)\n",
    "input_qt = QuantTensor(\n",
    "    input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True\n",
    ")\n",
    "\n",
    "# Move to CPU before export\n",
    "model_for_export.cpu()\n",
    "\n",
    "# Export to ONNX\n",
    "bo.export_finn_onnx(\n",
    "    model_for_export, export_path=ready_model_filename, input_t=input_qt\n",
    ")\n",
    "\n",
    "print(\"Model saved to %s\" % ready_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View the Exported ONNX in Netron\n",
    "\n",
    "Let's examine the exported ONNX model with [Netron](https://github.com/lutzroeder/netron), a visualizer for neural networks that allows interactive investigation of network properties. For example, you can click on the individual nodes and view their properties. Particular things of note:\n",
    "\n",
    "* The input tensor \"0\" is annotated with `quantization: finn_datatype: BIPOLAR`\n",
    "* The input preprocessing (x + 1) / 2 is exported as part of the network (initial `Add` and `Div` layers)\n",
    "* Brevitas `QuantLinear` layers are exported to ONNX as `MatMul`. We've exported the padded version; the shape of the first `MatMul` node's weight parameter is 600x64\n",
    "* The weight parameters (second inputs) for `MatMul` nodes are annotated with `quantization: finn_datatype: INT2`\n",
    "* The quantized activations are exported as `MultiThreshold` nodes with `domain=finn.custom_op.general`\n",
    "* There's a final `MultiThreshold` node with threshold=0 to produce the final bipolar output (this is the `qnt_output` from `CybSecMLPForExport`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.util.visualization import showInNetron\n",
    "\n",
    "showInNetron(ready_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## That's it! <a id=\"thats_it\" ></a>\n",
    "Congratulations! You created, trained and tested a quantized MLP that is ready to be loaded into FINN. You can now proceed to the next notebook."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}