From 6be464a40e53ef891da95e062e8882c6bdd821af Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Mon, 8 Feb 2021 19:49:33 +0100
Subject: [PATCH] [Notebook] add option to use prequantized dataset for cybsec
 MLP

---
 .../1-train-mlp-with-brevitas.ipynb           | 236 ++++++++++++++----
 .../2-export-to-finn-and-verify.ipynb         |  54 ++--
 2 files changed, 214 insertions(+), 76 deletions(-)

diff --git a/notebooks/end2end_example/cybersecurity/1-train-mlp-with-brevitas.ipynb b/notebooks/end2end_example/cybersecurity/1-train-mlp-with-brevitas.ipynb
index 84dde835d..d53a07ae8 100644
--- a/notebooks/end2end_example/cybersecurity/1-train-mlp-with-brevitas.ipynb
+++ b/notebooks/end2end_example/cybersecurity/1-train-mlp-with-brevitas.ipynb
@@ -40,15 +40,15 @@
     "-------------\n",
     "\n",
     "* [Initial setup](#initial_setup)\n",
-    "* [Define the Quantized MLP model](#define_quantized_mlp)\n",
     "* [Load the UNSW_NB15 dataset](#load_dataset) \n",
+    "    * [(Option 1, slower) Download original and quantize](#dataset_qnt_manual)\n",
+    "    * [(Option 2, faster) Use prequantized version](#dataset_qnt_pre)\n",
+    "* [Define the Quantized MLP model](#define_quantized_mlp)\n",
     "* [Define Train and Test  Methods](#train_test)\n",
-    "* [(Option 1) Train the Model from Scratch](#train_scratch)\n",
-    "* [(Option 2) Load Pre-Trained Parameters](#load_pretrained)\n",
+    "    * [(Option 1) Train the Model from Scratch](#train_scratch)\n",
+    "    * [(Option 2) Load Pre-Trained Parameters](#load_pretrained)\n",
     "* [Network Surgery Before Export](#network_surgery)\n",
-    "* [Export to FINN-ONNX](#export_finn_onnx)\n",
-    "* [View the Exported ONNX in Netron](#view_in_netron)\n",
-    "* [That's it!](#thats_it)"
+    "* [Export to FINN-ONNX](#export_finn_onnx)"
    ]
   },
   {
@@ -77,7 +77,17 @@
     "### Dataset Quantization <a id='dataset_qnt'></a>\n",
     "\n",
     "The goal of this notebook is to train a Quantized Neural Network (QNN) to be later deployed as an FPGA accelerator generated by the FINN compiler. Although we can choose a variety of different precisions for the input, [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf) have previously shown we can actually binarize the inputs and still get good (90%+) accuracy.\n",
-    "Thus, we will create a binarized representation for the dataset by following the procedure defined by [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf), which we repeat briefly here:\n",
+    "\n",
+    "For this part we offer two options: you can download the original dataset and apply the quantization functions to it (slower), or use a pre-quantized version (faster). Both options yield the same result, though the first gives more detail on how we quantize the dataset."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## (Option 1, slower) Download original and quantize <a id='dataset_qnt_manual'></a>\n",
+    "\n",
+    "We will create a binarized representation for the dataset by following the procedure defined by [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf), which we repeat briefly here:\n",
     "\n",
     "* Original features have different formats ranging from integers, floating numbers to strings.\n",
     "* Integers, which for example represent a packet lifetime, are binarized with as many bits as to include the maximum value. \n",
@@ -86,7 +96,7 @@
     "* In the end, each sample is transformed into a 593-bit wide binary vector. \n",
     "* All vectors are labeled as bad (0) or normal (1)\n",
     "\n",
-    "Following their open-source implementation provided as a Matlab script [here](https://github.com/TadejMurovic/BNN_Deployment/blob/master/cybersecurity_dataset_unswb15.m), we've created a [Python version](dataloader_quantized.py).\n",
+    "Following Murovic and Trost's open-source implementation provided as a Matlab script [here](https://github.com/TadejMurovic/BNN_Deployment/blob/master/cybersecurity_dataset_unswb15.m), we've created a [Python version](dataloader_quantized.py).\n",
     "This `UNSW_NB15_quantized` class implements a PyTorch `DataLoader`, which represents a Python iterable over a dataset. This is useful because enables access to data in batches."
    ]
   },
@@ -111,7 +121,16 @@
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Samples in each set: train = 175341, test = 82332\n",
+      "Shape of one input sample: torch.Size([593])\n"
+     ]
+    }
+   ],
    "source": [
     "from torch.utils.data import DataLoader, Dataset\n",
     "from dataloader_quantized import UNSW_NB15_quantized\n",
@@ -125,7 +144,19 @@
     "\n",
     "test_quantized_dataset = UNSW_NB15_quantized(file_path_train = file_path_train, \\\n",
     "                                              file_path_test = file_path_test, \\\n",
-    "                                              train=False)"
+    "                                              train=False)\n",
+    "\n",
+    "print(\"Samples in each set: train = %d, test = %d\" % (len(train_quantized_dataset), len(test_quantized_dataset)))\n",
+    "print(\"Shape of one input sample: \" +  str(train_quantized_dataset[0][0].shape))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## (Option 2, faster) Use prequantized version <a id='dataset_qnt_pre'></a>\n",
+    "\n",
+    "Downloading the original dataset and quantizing it can take some time, so we provide a pre-quantized version for your convenience. Uncomment the following line to download:"
    ]
   },
   {
@@ -134,21 +165,109 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# ! wget -O unsw_nb15_binarized.npz 'https://zenodo.org/record/4519767/files/unsw_nb15_binarized.npz?download=1'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can extract the binarized numpy arrays from the .npz archive and wrap them as a PyTorch `TensorDataset` as follows:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Samples in each set: train = 175341, test = 82332\n",
+      "Shape of one input sample: torch.Size([593])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "from torch.utils.data import TensorDataset\n",
+    "\n",
+    "def get_preqnt_dataset(data_dir: str, train: bool):\n",
+    "    unsw_nb15_data = np.load(data_dir + \"/unsw_nb15_binarized.npz\")\n",
+    "    if train:\n",
+    "        partition = \"train\"\n",
+    "    else:\n",
+    "        partition = \"test\"\n",
+    "    part_data = unsw_nb15_data[partition].astype(np.float32)\n",
+    "    part_data = torch.from_numpy(part_data)\n",
+    "    part_data_in = part_data[:, :-1]\n",
+    "    part_data_out = part_data[:, -1]\n",
+    "    return TensorDataset(part_data_in, part_data_out)\n",
+    "\n",
+    "train_quantized_dataset = get_preqnt_dataset(\".\", True)\n",
+    "test_quantized_dataset = get_preqnt_dataset(\".\", False)\n",
+    "\n",
+    "print(\"Samples in each set: train = %d, test = %d\" % (len(train_quantized_dataset), len(test_quantized_dataset)))\n",
+    "print(\"Shape of one input sample: \" +  str(train_quantized_dataset[0][0].shape))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set up DataLoader\n",
+    "\n",
+    "Following either option, we now have access to the quantized dataset. We will wrap the dataset in a PyTorch `DataLoader` for easier access in batches."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import DataLoader, Dataset\n",
+    "\n",
     "batch_size = 1000\n",
     "\n",
     "# dataset loaders\n",
     "train_quantized_loader = DataLoader(train_quantized_dataset, batch_size=batch_size, shuffle=True)\n",
-    "test_quantized_loader = DataLoader(test_quantized_dataset, batch_size=batch_size, shuffle=True)    "
+    "test_quantized_loader = DataLoader(test_quantized_dataset, batch_size=batch_size, shuffle=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Input shape for 1 batch: torch.Size([1000, 593])\n",
+      "Label shape for 1 batch: torch.Size([1000])\n"
+     ]
+    }
+   ],
+   "source": [
+    "count = 0\n",
+    "for x,y in train_quantized_loader:\n",
+    "    print(\"Input shape for 1 batch: \" + str(x.shape))\n",
+    "    print(\"Label shape for 1 batch: \" + str(y.shape))\n",
+    "    count += 1\n",
+    "    if count == 1:\n",
+    "        break"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Define the Quantized MLP Model <a id='define_quantized_mlp'></a>\n",
+    "# Define the Quantized MLP Model <a id='define_quantized_mlp'></a>\n",
     "\n",
     "We'll now define an MLP model that will be trained to perform inference with quantized weights and activations.\n",
-    "For this, we'll use the quantization-aware training (QAT) capabilities offered by[Brevitas](https://github.com/Xilinx/brevitas).\n",
+    "For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).\n",
     "\n",
     "Our MLP will have four fully-connected (FC) layers in total: three hidden layers with 64 neurons, and a final output layer with a single output, all using 2-bit weights. We'll use 2-bit quantized ReLU activation functions, and apply batch normalization between each FC layer and its activation.\n",
     "\n",
@@ -157,7 +276,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -179,7 +298,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -214,13 +333,13 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Define Train and Test  Methods  <a id='train_test'></a>\n",
+    "# Define Train and Test  Methods  <a id='train_test'></a>\n",
     "The train and test methods will use a `DataLoader`, which feeds the model with a new predefined batch of training data in each iteration, until the entire training data is fed to the model. Each repetition of this process is called an `epoch`."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -249,7 +368,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -281,7 +400,16 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## (Option 1) Train the Model from Scratch <a id=\"train_scratch\"></a>\n"
+    "# Train the QNN <a id=\"train_qnn\"></a>\n",
+    "\n",
+    "We provide two options for training below: you can either train the model from scratch (slower) or load pre-trained parameters (faster). The first option gives more insight into how the training process works, while the second will likely give better accuracy."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## (Option 1, slower) Train the Model from Scratch <a id=\"train_scratch\"></a>\n"
    ]
   },
   {
@@ -293,7 +421,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -311,7 +439,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -322,7 +450,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 14,
    "metadata": {
     "scrolled": true
    },
@@ -331,7 +459,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Training loss = 0.132304 test accuracy = 0.808811: 100%|██████████| 15/15 [04:52<00:00, 19.47s/it]\n"
+      "Training loss = 0.130799 test accuracy = 0.794551: 100%|██████████| 15/15 [01:25<00:00,  5.66s/it]\n"
      ]
     }
    ],
@@ -359,14 +487,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 15,
    "metadata": {
     "scrolled": true
    },
    "outputs": [
     {
      "data": {
-      "image/png": "\n",
+      "image/png": "\n",
       "text/plain": [
        "<Figure size 432x288 with 1 Axes>"
       ]
@@ -387,12 +515,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
      "data": {
-      "image/png": "\n",
+      "image/png": "\n",
       "text/plain": [
        "<Figure size 432x288 with 1 Axes>"
       ]
@@ -410,16 +538,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "0.8088106689986883"
+       "0.7945513287664577"
       ]
      },
-     "execution_count": 14,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -430,7 +558,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -442,14 +570,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## (Option 2) Load Pre-Trained Parameters <a id=\"load_pretrained\"></a>\n",
+    "## (Option 2, faster) Load Pre-Trained Parameters <a id=\"load_pretrained\"></a>\n",
     "\n",
     "Instead of training from scratch, you can also use pre-trained parameters we provide here. These parameters should achieve ~91.9% test accuracy."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
@@ -458,7 +586,7 @@
        "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
       ]
      },
-     "execution_count": 16,
+     "execution_count": 19,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -473,7 +601,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 20,
    "metadata": {
     "scrolled": true
    },
@@ -484,7 +612,7 @@
        "0.9188772287810328"
       ]
      },
-     "execution_count": 17,
+     "execution_count": 20,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -497,7 +625,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Network Surgery Before Export <a id=\"network_surgery\"></a>\n",
+    "# Network Surgery Before Export <a id=\"network_surgery\"></a>\n",
     "\n",
     "Sometimes, it's desirable to make some changes to our trained network prior to export (this is known in general as \"network surgery\"). This depends on the model and is not generally necessary, but in this case we want to make a couple of changes to get better results with FINN."
    ]
@@ -511,7 +639,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
@@ -520,7 +648,7 @@
        "(64, 593)"
       ]
      },
-     "execution_count": 18,
+     "execution_count": 21,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -536,7 +664,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -545,7 +673,7 @@
        "(64, 600)"
       ]
      },
-     "execution_count": 19,
+     "execution_count": 22,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -560,7 +688,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [
     {
@@ -569,7 +697,7 @@
        "torch.Size([64, 600])"
       ]
      },
-     "execution_count": 20,
+     "execution_count": 23,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -592,7 +720,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -620,7 +748,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -650,7 +778,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
@@ -659,7 +787,7 @@
        "0.9188772287810328"
       ]
      },
-     "execution_count": 23,
+     "execution_count": 26,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -672,14 +800,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Export to FINN-ONNX <a id=\"export_finn_onnx\" ></a>\n",
+    "# Export to FINN-ONNX <a id=\"export_finn_onnx\" ></a>\n",
     "\n",
     "FINN expects an ONNX model as input. We'll now export our network into ONNX to be imported and used in FINN for the next notebooks. Note that the particular ONNX representation used for FINN differs from standard ONNX, you can read more about this [here](https://finn.readthedocs.io/en/latest/internals.html#intermediate-representation-finn-onnx)."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 27,
    "metadata": {
     "scrolled": true
    },
@@ -714,7 +842,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## View the Exported ONNX in Netron <a id=\"view_in_netron\" ></a>\n",
+    "## View the Exported ONNX in Netron\n",
     "\n",
     "Let's examine the exported ONNX model with Netron. Particular things of note:\n",
     "\n",
@@ -727,7 +855,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [
     {
@@ -744,17 +872,17 @@
        "        <iframe\n",
        "            width=\"100%\"\n",
        "            height=\"400\"\n",
-       "            src=\"http://0.0.0.0:8081/\"\n",
+       "            src=\"http://localhost:8081/\"\n",
        "            frameborder=\"0\"\n",
        "            allowfullscreen\n",
        "        ></iframe>\n",
        "        "
       ],
       "text/plain": [
-       "<IPython.lib.display.IFrame at 0x7f3a74f3a5c0>"
+       "<IPython.lib.display.IFrame at 0x7f8a808b3b70>"
       ]
      },
-     "execution_count": 24,
+     "execution_count": 28,
      "metadata": {},
      "output_type": "execute_result"
     }
diff --git a/notebooks/end2end_example/cybersecurity/2-export-to-finn-and-verify.ipynb b/notebooks/end2end_example/cybersecurity/2-export-to-finn-and-verify.ipynb
index 358615dbe..0f69019a3 100644
--- a/notebooks/end2end_example/cybersecurity/2-export-to-finn-and-verify.ipynb
+++ b/notebooks/end2end_example/cybersecurity/2-export-to-finn-and-verify.ipynb
@@ -93,14 +93,14 @@
        "        <iframe\n",
        "            width=\"100%\"\n",
        "            height=\"400\"\n",
-       "            src=\"http://0.0.0.0:8081/\"\n",
+       "            src=\"http://localhost:8081/\"\n",
        "            frameborder=\"0\"\n",
        "            allowfullscreen\n",
        "        ></iframe>\n",
        "        "
       ],
       "text/plain": [
-       "<IPython.lib.display.IFrame at 0x7fc1fc950748>"
+       "<IPython.lib.display.IFrame at 0x7fe10c830e48>"
       ]
      },
      "execution_count": 3,
@@ -124,7 +124,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -145,6 +145,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
+    "**Would the FINN compiler still work if we didn't do this?** The compilation step in the next notebook applies these transformations internally, so compilation would work fine. However, the FINN verification capabilities we use below require the tidy-up transformations to have been applied.\n",
+    "\n",
     "There's one more thing we'll do: we will mark the input tensor datatype as bipolar, which will be used by the compiler later on. \n",
     "\n",
     "*In the near future it will be possible to add this information to the model while exporting, instead of having to add it manually.*"
@@ -152,7 +154,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -213,14 +215,14 @@
        "        <iframe\n",
        "            width=\"100%\"\n",
        "            height=\"400\"\n",
-       "            src=\"http://0.0.0.0:8081/\"\n",
+       "            src=\"http://localhost:8081/\"\n",
        "            frameborder=\"0\"\n",
        "            allowfullscreen\n",
        "        ></iframe>\n",
        "        "
       ],
       "text/plain": [
-       "<IPython.lib.display.IFrame at 0x7fc280154278>"
+       "<IPython.lib.display.IFrame at 0x7fe10a956e80>"
       ]
      },
      "execution_count": 6,
@@ -245,7 +247,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -254,22 +256,30 @@
        "torch.Size([100, 593])"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "from torch.utils.data import DataLoader, Dataset\n",
-    "from dataloader_quantized import UNSW_NB15_quantized\n",
+    "import numpy as np\n",
+    "from torch.utils.data import TensorDataset\n",
     "\n",
-    "test_quantized_dataset = UNSW_NB15_quantized(file_path_train='UNSW_NB15_training-set.csv', \\\n",
-    "                                              file_path_test = \"UNSW_NB15_testing-set.csv\", \\\n",
-    "                                              train=False)\n",
+    "def get_preqnt_dataset(data_dir: str, train: bool):\n",
+    "    unsw_nb15_data = np.load(data_dir + \"/unsw_nb15_binarized.npz\")\n",
+    "    if train:\n",
+    "        partition = \"train\"\n",
+    "    else:\n",
+    "        partition = \"test\"\n",
+    "    part_data = unsw_nb15_data[partition].astype(np.float32)\n",
+    "    part_data = torch.from_numpy(part_data)\n",
+    "    part_data_in = part_data[:, :-1]\n",
+    "    part_data_out = part_data[:, -1]\n",
+    "    return TensorDataset(part_data_in, part_data_out)\n",
     "\n",
     "n_verification_inputs = 100\n",
-    "# last column is the label, exclude it\n",
-    "input_tensor = test_quantized_dataset.data[:n_verification_inputs,:-1]\n",
+    "test_quantized_dataset = get_preqnt_dataset(\".\", False)\n",
+    "input_tensor = test_quantized_dataset.tensors[0][:n_verification_inputs]\n",
     "input_tensor.shape"
    ]
   },
@@ -282,7 +292,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -291,7 +301,7 @@
        "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -335,7 +345,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -365,7 +375,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -395,14 +405,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "ok 100 nok 0: 100%|██████████| 100/100 [00:46<00:00,  2.17it/s]\n"
+      "ok 100 nok 0: 100%|██████████| 100/100 [00:47<00:00,  2.11it/s]\n"
      ]
     }
    ],
@@ -431,7 +441,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {
-- 
GitLab