diff --git a/notebooks/brevitas-network-import.ipynb b/notebooks/brevitas-network-import.ipynb index cbe494e3165983ddc170989196c5c133e84b168b..9099cb97ffbc6dfad5077ac02888fc9e5eb13af4 100644 --- a/notebooks/brevitas-network-import.ipynb +++ b/notebooks/brevitas-network-import.ipynb @@ -4,12 +4,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Importing Brevitas networks into FINN" + "# Importing Brevitas networks into FINN\n", + "\n", + "In this notebook we'll go through an example of how to import a Brevitas-trained QNN into FINN. The steps will be as follows:\n", + "\n", + "1. Load up the trained PyTorch model\n", + "2. Call Brevitas FINN-ONNX export and visualize with Netron\n", + "3. Import into FINN and call cleanup transformations\n", + "\n", + "We'll use the following showSrc function to print the source code for function calls in the Jupyter notebook:" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -19,9 +27,18 @@ " print(\"\".join(inspect.getsourcelines(what)[0]))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load up the trained PyTorch model\n", + "\n", + "The FINN Docker image comes with several [example Brevitas networks](https://github.com/maltanar/brevitas_cnv_lfc), and we'll use the LFC-w1a1 model as the example network here. This is a binarized fully connected network trained on the MNIST dataset. Let's start by looking at what the PyTorch network definition looks like:" + ] + }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -77,44 +94,509 @@ "showSrc(LFC)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the network topology is constructed using a few helper functions that generate the quantized linear layers and quantized activations. The bitwidth of the layers is actually parametrized in the constructor, so let's instantiate a 1-bit weights and activations version of this network. We also have pretrained weights for this network, which we will load into the model." + ] + }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { - "name": "stdout", + "data": { + "text/plain": [ + "LFC(\n", + " (features): ModuleList(\n", + " (0): QuantHardTanh(\n", + " (act_quant_proxy): ActivationQuantProxy(\n", + " (fused_activation_quant_proxy): FusedActivationQuantProxy(\n", + " (activation_impl): Identity()\n", + " (tensor_quant): ClampedBinaryQuant(\n", + " (scaling_impl): StandaloneScaling(\n", + " (restrict_value): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (1): Dropout(p=0.2)\n", + " (2): QuantLinear(\n", + " in_features=784, out_features=1024, bias=False\n", + " (weight_reg): WeightReg()\n", + " (weight_quant): WeightQuantProxy(\n", + " (tensor_quant): BinaryQuant(\n", + " (scaling_impl): ParameterStatsScaling(\n", + " (parameter_list_stats): ParameterListStats(\n", + " (first_tracked_param): _ViewParameterWrapper()\n", + " (stats): Stats(\n", + " (stats_impl): AbsAve()\n", + " )\n", + " )\n", + " (stats_scaling_impl): StatsScaling(\n", + " (affine_rescaling): Identity()\n", + " (restrict_scaling): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " (restrict_scaling_preprocess): LogTwo()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (bias_quant): BiasQuantProxy()\n", + " )\n", + " (3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (4): QuantHardTanh(\n", + " (act_quant_proxy): ActivationQuantProxy(\n", + " (fused_activation_quant_proxy): FusedActivationQuantProxy(\n", + " (activation_impl): Identity()\n", + " (tensor_quant): ClampedBinaryQuant(\n", + " (scaling_impl): StandaloneScaling(\n", + " (restrict_value): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (5): Dropout(p=0.2)\n", + " (6): QuantLinear(\n", + " in_features=1024, out_features=1024, bias=False\n", + " (weight_reg): WeightReg()\n", + " (weight_quant): WeightQuantProxy(\n", + " (tensor_quant): BinaryQuant(\n", + " (scaling_impl): ParameterStatsScaling(\n", + " (parameter_list_stats): ParameterListStats(\n", + " (first_tracked_param): _ViewParameterWrapper()\n", + " (stats): Stats(\n", + " (stats_impl): AbsAve()\n", + " )\n", + " )\n", + " (stats_scaling_impl): StatsScaling(\n", + " (affine_rescaling): Identity()\n", + " (restrict_scaling): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " (restrict_scaling_preprocess): LogTwo()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (bias_quant): BiasQuantProxy()\n", + " )\n", + " (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): QuantHardTanh(\n", + " (act_quant_proxy): ActivationQuantProxy(\n", + " (fused_activation_quant_proxy): FusedActivationQuantProxy(\n", + " (activation_impl): Identity()\n", + " (tensor_quant): ClampedBinaryQuant(\n", + " (scaling_impl): StandaloneScaling(\n", + " (restrict_value): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (9): Dropout(p=0.2)\n", + " (10): QuantLinear(\n", + " in_features=1024, out_features=1024, bias=False\n", + " (weight_reg): WeightReg()\n", + " (weight_quant): WeightQuantProxy(\n", + " (tensor_quant): BinaryQuant(\n", + " (scaling_impl): ParameterStatsScaling(\n", + " (parameter_list_stats): ParameterListStats(\n", + " (first_tracked_param): _ViewParameterWrapper()\n", + " (stats): Stats(\n", + " (stats_impl): AbsAve()\n", + " )\n", + " )\n", + " (stats_scaling_impl): StatsScaling(\n", + " (affine_rescaling): Identity()\n", + " (restrict_scaling): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " (restrict_scaling_preprocess): LogTwo()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (bias_quant): BiasQuantProxy()\n", + " )\n", + " (11): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (12): QuantHardTanh(\n", + " (act_quant_proxy): ActivationQuantProxy(\n", + " (fused_activation_quant_proxy): FusedActivationQuantProxy(\n", + " (activation_impl): Identity()\n", + " (tensor_quant): ClampedBinaryQuant(\n", + " (scaling_impl): StandaloneScaling(\n", + " (restrict_value): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (13): Dropout(p=0.2)\n", + " )\n", + " (fc): QuantLinear(\n", + " in_features=1024, out_features=10, bias=False\n", + " (weight_reg): WeightReg()\n", + " (weight_quant): WeightQuantProxy(\n", + " (tensor_quant): BinaryQuant(\n", + " (scaling_impl): ParameterStatsScaling(\n", + " (parameter_list_stats): ParameterListStats(\n", + " (first_tracked_param): _ViewParameterWrapper()\n", + " (stats): Stats(\n", + " (stats_impl): AbsAve()\n", + " )\n", + " )\n", + " (stats_scaling_impl): StatsScaling(\n", + " (affine_rescaling): Identity()\n", + " (restrict_scaling): RestrictValue(\n", + " (forward_impl): Sequential(\n", + " (0): PowerOfTwo()\n", + " (1): ClampMin()\n", + " )\n", + " )\n", + " (restrict_scaling_preprocess): LogTwo()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (bias_quant): BiasQuantProxy()\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "trained_lfc_w1a1_checkpoint = \"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar\"\n", + "lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()\n", + "checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location=\"cpu\")\n", + "lfc.load_state_dict(checkpoint[\"state_dict\"])\n", + "lfc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have now instantiated our trained PyTorch network. Let's try to run an example MNIST image through the network using PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAATB0lEQVR4nO3dfWxd5X0H8O/3XttxXhwSJ8GYJECIYIPSNQUPWoEmCm2agqbAtFGiFZGK1f0DpNKyF8S0FU3ahmAtmraOyW0iwsroulJWptEXSJkoaKAkKCThNbwkEC/EhLzY5MW5vve3P3zo3ODze8w999xzm+f7kSzb9+dz7pNrf3Pt+zvP89DMICInvlLRAxCR5lDYRSKhsItEQmEXiYTCLhKJtmbeWQenWSdmNvMuTwwM1LM0VBg4ecZuDcvl9FNXq9nO3ZZ+bgCwsfrPz/Z2/9yVSt3nztNRHMIxG530m5op7CRXAPh7AGUA3zGzO7yv78RMXMTLs9zliSkQOC8wAGBjY/XfdXuHf+5QIGt+vTz7pNRa9cBB/9wB5bnz3Hr13X3pxcB/Ym09p7r1scH/detFecbWp9bq/jWeZBnAtwB8DsC5AFaRPLfe84lIvrL8zX4hgFfN7HUzOwbgewBWNmZYItJoWcK+EMBbEz7fldz2K0j2k9xIcmMFoxnuTkSyyP3VeDMbMLM+M+trx7S8705EUmQJ+yCAxRM+X5TcJiItKEvYNwA4i+QSkh0ArgXwcGOGJSKNVnfrzczGSN4E4KcYb72tNbPnGzayiJTnzHHr1f376z53acYMt147fNg/QaAtWJ4926277bWS31IMCrQkvfYa2/wf/eqeoXpG1NIy9dnN7BEAjzRoLCKSI10uKxIJhV0kEgq7SCQUdpFIKOwikVDYRSLR1PnssQr1uu3YsWznn5m+RkDt0KFs5w6MvTo8XPe523p73HpoGmmwF+5cI1DuPcW/77d2+acOTQ2uZPue5kHP7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSar01QXAaaUhgmmmm9lqe5w7IukJreU76yrWAP7021FoLTd3N/D0tgJ7ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIqM/eBOV53W7d3W0UQHn+fP8ORp1ttU72dzq1Gf4uPaUhfxnrvcvPdOvzvrgztVap+UtBt31tlluvvfCaW88iy9TdVqVndpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEuqzN0Gojx7y8t2L3Porn1qTWjti/pLGpcD/93tr/vE9Zb9P75nGdre+fNZqt14u+2O3SnqN0/xxm3ftAoBSV5dbr42MuPUiZAo7yR0ARgBUAYyZWV8jBiUijdeIZ/ZPmdneBpxHRHKkv9lFIpE17AbgZyQ3keyf7AtI9pPcSHJjBf7fQSKSn6y/xl9iZoMkTwbwKMmXzOyJiV9gZgMABgBgNrst4/2JSJ0yPbOb2WDyfgjAQwAubMSgRKTx6g47yZkku97/GMByANsaNTARaawsv8b3AHiI4+uOtwH4VzP7SUNGdYIJbXt8YOVvufUvfPQXbr3M9P+zd1b8v5y6Sn4f/bQ2f0551WpufdTGUmv7a/7a6x/7h+f8c9f8H98tf/HbqbVpP97gHtu28FS3nnXN+yLUHXYzex3Axxo4FhHJkVpvIpFQ2EUiobCLREJhF4mEwi4SCZo176K22ey2i3h50+7v18Wql/w2zurZQ279tcp7qbWl7X7r7GDtiFsvw9/SOTRFtob01tysUqd7bMgbzr8bAJ47dkpq7R+/dI17bPm/n/Xrc+e69ep+fwnuvDxj6zFs+yb9pumZXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJxK/XUtJ0er7ONM/xst8vtlqG6w0C0zwRuJZhzW1Xu/VT71zr1pfPSO+lh6agjtSqbv2y+//ErS99wO8n1zrTl4t+c4W/HPOG/m+69SWBawjmlNKvT/jjK/2lpH/j1YVufWzXoFtvRXpmF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUi0fz57KVP13+CJo61kbLOfT6y0t97Y9856ZdLjE13D4W1+Y/pmf8W6KNvecm/gwzO2+Q/F311wRNufVFgGWzPZ09d5tY1n11EWpbCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSLR/Pns3rzz0LzwImWYS5+15zr9Pze59UWPlFNrVvG3ZC7POcmtVw8cdOuh7ahRS/+e1o4edQ996m7/+oKv/o3fZ8+C7R1uvag+ehbBZ3aSa0kOkdw24bZuko+S3J68968wEJHCTeXX+HsBrDjutlsBrDezswCsTz4XkRYWDLuZPQFg33E3rwSwLvl4HYCrGjwuEWmwev9m7zGz3cnHbwPoSftCkv0A+gGgE4G/70QkN5lfjbfxmTSpsynMbMDM+sysrx3+In8ikp96w76HZC8AJO/9bUZFpHD1hv1hANcnH18P4EeNGY6I5CX4NzvJBwBcCmA+yV0Avg7gDgDfJ3kDgJ0A/M2uJwqsU143rw8OZF5XPnS8xwL/5vL8eW69uvfdTOd3j61mu7ahdnQ08AX1j23ulgNuPct89Yr54ypND+wdH6hXh4c/7JByFwy7ma1KKV3e4LGISI50uaxIJBR2kUgo7CKRUNhFIqGwi0RCWzYngls21/ypou59t/kPc3Wf32IKKc/rdk7ut5hCU1iDAq21Umd6i4rT/XWu31ru/LumwNuu+r1aoGU4LXC159hYHSMqlp7ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFItFafPTRN1RNYhtrGitvu2QK97qyq7x6/ROD/Y6Bf7PXBAYCd/vGhPr23XHRbYBnrpb/7mlvfWz3k1ueXZ6bWRkLLlgeuHzghl5IWkRODwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUi0dw+O/253eE55fn2q12l9G2RWU6vTUWe2ypXLj7PPXbvR/0+etsR/3uyYN2zbt01w5/PPnDm/W49yyLY3z1wgVv3rl0AgFJXl1uvjYx86DHlTc/sIpFQ2EUiobCLREJhF4mEwi4SCYVdJBIKu0gkmjyfnf767lbJ8a79ufKlwBrmnJk+N5od7e6xb646w61b4Ltwxoo33Pqc9vTH9NaF33KPPafdH3st0M3+0hf9zXxrlj62L5z8Y/fYo+b3+HvK/jUCd+1bmlp76rNL3GNLM/wtl1uxjx4SfGYnuZbkEMltE267neQgyc3J2xX5DlNEsprKr/H3Algxye13m9my5O2Rxg5LRBotGHYzewKAf+2giLS8LC/Q3URyS/Jr/ty0LyLZT3IjyY0VS1+PTETyVW/Y7wGwFMAyALsBfCPtC81swMz6zKyvnf7ihiKSn7rCbmZ7zKxqZjUA3wZwYWOHJSKNVlfYSfZO+PRqANvSvlZEWkOwz07yAQCXAphPcheArwO4lOQyAAZgB4AvT+nezNy522zv8A8PzPv2lM89263v+L15bv03P7M9tTaw5N/dY731ywHgjcp7bn1J+yy3vmss/fhFbf6xIW9Ujrj1f178mFufUUr/nr5S8dd9Py3j2Be0pffCX/6a32c/669OvOevYNjNbNUkN6/JYSwikiNdLisSCYVdJBIKu0gkFHaRSCjsIpFo/pbNzpLMWVprIS/fkHpFLwDgtWv/ya1vGk0fW6i1FtJV8qffrj/iL1W99Wj6ctErZ/ktpAVl/0cg1PYL2V89nFo7u91/3CrmLx1+2Pyfl9Wzh1Jrl33+LvfYaz6y2q13/6H/uLXils56ZheJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFItH8Pruz7XKWrYlDW+i+9Hl/SWXA72VfMC19qubTR/1+8EMH/e2BH7vnk259/sD/uPXKp9PP/8rfnuIee+OCx936R/xZx/jJYX8558Vt6VNkz//pH7nHTt/p3/knr9zi1tec9mRq7aj51zY8vewHbn3Fg1e6dVymPruIFERhF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpGgBbbFbaSTSvPsE53pG77WjvrbQ7Wdvji1NvbmLvfYe3f+wq2HNoteWJ6RWit721ADWHPQ73Uv63zTrb9dne3WXx/tSa1d3fW8e2y3s9Qz4C8FDQAvHkufrw4Aq//yltTanPv86wdC2hYtdOv8bvp20/cu9fvo71T9PvzMUmAr69Mucet5ecbWY9j2TTp4PbOLREJhF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFoap99NrvtIl6ey7lD2z3v+cGZbv2pC+5z616/ebezZTIA9Aa2Hg5t2byobbpbH7X0qwRmlTrdY28avMit//w//Ln4J2/2r1CY9l8bUmttvf71B2N73nHrpY52t+5dt1FZ3uce+/N7v+PWz3nqOrd+2h9sdet5ydRnJ7mY5OMkXyD5PMmvJLd3k3yU5Pbkvb8Lg4gUaiq/xo8BuMXMzgXwCQA3kjwXwK0A1pvZWQDWJ5+LSIsKht3MdpvZs8nHIwBeBLAQwEoA65IvWwfgqrwGKSLZfag16EieAeDjAJ4B0GNmu5PS2wAmvUCbZD+AfgDoRPr15SKSrym/Gk9yFoAHAdxsZsMTazb+Kt+kr/SZ2YCZ9ZlZXzv8xQlFJD9TCjvJdowH/X4z+2Fy8x6SvUm9F0D6lpkiUrhg640kMf43+T4zu3nC7XcBeNfM7iB5K4BuM/tT71yh1lvbktPdsYy9sdOte0qdfgsKZ5/hlof+On1K4/knD7rHvjo83613dYy69Z37/UZHz53pbUE+7W/ZXJ7lb5tsx7Jtox2atuzhNP83QRv1HzfQmaYa+Lkvz5/n1qv7Dvj37SyZniev9TaVv9kvBnAdgK0kNye33QbgDgDfJ3kDgJ0ArmnEYEUkH8Gwm9mTANL+i8znChkRaThdLisSCYVdJBIKu0gkFHaRSCjsIpFo6pbNLJdQnpW+LHKoj+71Pqt73/Xve7o/TbS65SW3vuD303u+bwb6vR30l1seDfR8T+30e7peL7s0w79EuTo87NazKs1M7+PXDh3yjw302auBx50d6dcfhHr0oZ+n8mx/ee+8H9d66JldJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhF4lEU/vsVq25/cfgHGKn9xlaSrq6f78/uIDywt7U2tjrO/yDA330UleXW6+NjPjn95TL9R8LoDznJLdePXDQrdcO+9cYuOcO9apL/r8tON/dO3XB1yfkQc/sIpFQ2EUiobCLREJhF4mEwi4SCYVdJBIKu0gkmtpnDwnNIfZYJdv65iHBXnoGmfroOZ871EcPynNL8BzXZs9yfUCr0jO7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhKJYNhJLib5OMkXSD5P8ivJ7beTHCS5OXm7Iv/hiki9pnJRzRiAW8zsWZJdADaRfDSp3W1mf5ff8ESkUaayP/tuALuTj0dIvghgYd4DE5HG+lB/s5M8A8DHATyT3HQTyS0k15Kcm3JMP8mNJDdWUP8yQSKSzZTDTnIWgAcB3GxmwwDuAbAUwDKMP/N/Y7LjzGzAzPrMrK8d/t5dIpKfKYWdZDvGg36/mf0QAMxsj5lVzawG4NsALsxvmCKS1VRejSeANQBeNLNvTrh94nKrVwPY1vjhiUijTOXV+IsBXAdgK8nNyW23AVhFchkAA7ADwJdzGaGINMRUXo1/EgAnKT3S+OGISF50BZ1IJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJBC3PLXWPvzPyHQA7J9w0H8Depg3gw2nVsbXquACNrV6NHNvpZrZgskJTw/6BOyc3mllfYQNwtOrYWnVcgMZWr2aNTb/Gi0RCYReJRNFhHyj4/j2tOrZWHRegsdWrKWMr9G92EWmeop/ZRaRJFHaRSBQSdpIrSL5M8lWStxYxhjQkd5DcmmxDvbHgsawlOURy24Tbukk+SnJ78n7SPfYKGltLbOPtbDNe6GNX9PbnTf+bnWQZwCsAPgNgF4ANAFaZ2QtNHUgKkjsA9JlZ4RdgkPwdAO8BuM/MzktuuxPAPjO7I/mPcq6Z/VmLjO12AO8VvY13sltR78RtxgFcBWA1CnzsnHFdgyY8bkU8s18I4FUze93MjgH4HoCVBYyj5ZnZEwD2HXfzSgDrko/XYfyHpelSxtYSzGy3mT2bfDwC4P1txgt97JxxNUURYV8I4K0Jn+9Ca+33bgB+RnITyf6iBzOJHjPbnXz8NoCeIgczieA23s103DbjLfPY1bP9eVZ6ge6DLjGz8wF8DsCNya+rLcnG/wZrpd7plLbxbpZJthn/pSIfu3q3P8+qiLAPAlg84fNFyW0twcwGk/dDAB5C621Fvef9HXST90MFj+eXWmkb78m2GUcLPHZFbn9eRNg3ADiL5BKSHQCuBfBwAeP4AJIzkxdOQHImgOVova2oHwZwffLx9QB+VOBYfkWrbOOdts04Cn7sCt/+3Mya/gbgCoy/Iv8agD8vYgwp4zoTwHPJ2/NFjw3AAxj/ta6C8dc2bgAwD8B6ANsBPAagu4XG9i8AtgLYgvFg9RY0tksw/iv6FgCbk7crin7snHE15XHT5bIikdALdCKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJP4PSkcHEGlbZOgAAAAASUVORK5CYII=\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "from pkgutil import get_data\n", + "import onnx\n", + "import onnx.numpy_helper as nph\n", + "raw_i = get_data(\"finn\", \"data/onnx/mnist-conv/test_data_set_0/input_0.pb\")\n", + "input_tensor = onnx.load_tensor_from_string(raw_i)\n", + "input_tensor_npy = nph.to_array(input_tensor)\n", + "input_tensor_pyt = torch.from_numpy(input_tensor_npy).float()\n", + "imgplot = plt.imshow(input_tensor_npy.reshape(28,28))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([2.4663e-03, 6.8211e-06, 8.9177e-01, 2.1330e-05, 3.6883e-04, 3.0418e-06,\n", + " 1.1795e-04, 5.0158e-05, 1.0517e-01, 2.4597e-05])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from torch.nn.functional import softmax\n", + "# do forward pass in PyTorch/Brevitas\n", + "produced = lfc.forward(input_tensor_pyt).detach()\n", + "probabilities = softmax(produced, dim=-1).flatten()\n", + "probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEICAYAAABS0fM3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAa3klEQVR4nO3debxdZXn28d9FEgRCBCFxgASIGCxxeAuNgKJCRSqoQB1qwRcrVkXfFkVBK1oriIqvI7WKVgQnZCgC2qABxFdE3yKBMBsmQwQSQAlzGEoYrv6xni2bk3P2WSRZ68BZ1/fz2Z+z13jfe59z1r3X86z1bNkmIiK6a62xTiAiIsZWCkFERMelEEREdFwKQUREx6UQRER0XApBRETHpRBE1CTpMEk/KM83k3SvpAmrsJ+PSTpmzWc4MOaOkn5Xcv7rNmPHk18KwTgk6XpJrx5m/s6SHi0Hg97j9L7lW0n6oaTbJN0t6XJJB63KwW5I3AMkLZD0oKTvPsFt3yLpPEn3S/rlKOv2v77lkq6R9I7VyX0ktm+0vb7tR2rktHTItkfYflcTeQ1wOPC1kvOPV3dnkr4r6dNrIK94Ekgh6J6by8Gg99gDQNKWwHxgCfAi2xsAfwPMAaasbkzg08C3V2HbO4B/Bf5v3Vi21weeDnwE+Jak2UNXkjRxFXJ5KtscWLgqG3bwveqcFILo+SRwnu2DbN8CYPsa22+1fdfQlSX9paQr+qbPlnRh3/Sve00Qtk8rn0JvH2Y/z5D0E0nLJN1Znk/vLbf9c9snUxWT2lz5MXAnMFvSFpIs6Z2SbgR+UeLvUM447pJ0maSd+3KbKenccnZxNjC1b1lvfxPL9EaSviPp5vI6fixpMnAGsEnfGdgm/U1MZds9JS0sOfxS0tZ9y66X9KFydna3pP+QtE5ZNrW8X3dJuqO85yv9T0u6DngucHrJ4Wklj7llu0WS3t23/mGSTpH0A0n3APsNeq/73ot3SFpSXv97Jb2k5H2XpK/1rb+lpF9Iur2cfR4vacO+5dtKuqS87z8sr/nTfctfL+nSst/zJL14UH4xuhSC6Hk1cMoTWP98YFY5GE0CXkx1wJsiaV2qM4lf19jPWsB3qD6xbgY8AHxt4BY1SFpL0huADYEr+hbtBGwNvEbSpsBPqc5WNgI+BJwqaVpZ9wTgIqoC8Cng7QNCHgesB7wAeCZwpO37gN15/FnY4wqapK2AE4EPANOAeVQH7LX7VnsLsBswk+p93q/MPxhYWrZ7FvAxYKUxY2xvCdwI7FFyeBA4qWy7CfBm4AhJr+rbbC+qv4cNgeMHvO5+2wOzgL+lOov7Z6q/qxcAb5G0U+9lA58tsbcGZgCHlfdjbeBHwHepficnAm/oBZC0DdWZ5XuAjYFvAnMlPa1mjjGMFILu2aR8kuo93lLmbwzcUncnth8ALgReCfwFcBnwX8COwA7A72yvdAYwzH5ut32q7fttLwc+Q3WwXlWbSLoLuA04FHib7Wv6lh9m+76S/77APNvzbD9q+2xgAfBaSZsBLwH+xfaDtn8FnM4wJD2H6oD/Xtt32n7I9rk18/1b4Ke2z7b9EPBFYF3gZX3r/Jvtm23fUXL48zL/IeA5wOYl5q9dY/AwSTOofk8fsf3fti8FjgH+rm+139j+cXlfHqj5Wj5V9vcz4D7gRNu32r6J6kPBNgC2F5XX+6DtZcCXeex3vgMwsbzmh2yfBlzQF2N/4Ju259t+xPb3gAfLdrGK0vbXPTfbnj7M/NupDirDkvTvVAdOgCNsHwGcC+xM9cnyXKpmmJ2o/jFrHQglrQccSfWJ9xll9hRJE0briB3BSK+vZ0nf882Bv5G0R9+8ScA5VJ9W7yyf6ntuoPr0OtQM4A7bd65CvpuU/QJg+1FJS4BN+9b5Q9/z+8s2AF+g+iT9M0kAR9uu05eyScl3ed+8G6jO4nqW8MT9se/5A8NMrw8g6VnAV4BXUPU/rUX1t9PL7aYhBW3o7+ztkt7XN29tHntPYhXkjCB6fg68aaSFtt/b17xxRJndKwSvLM/PpSoEO1GzEFA1bzwf2N7208u+oGo+aMLQA8xxtjfse0wuB9NbgGeUdv6ezUbY5xJgo/527hHiDedmqoMbAKqO6DOAm0Z9IfZy2wfbfi6wJ3CQpF1G267E3EhS/0UAmw2J2eSwxEeU/b+o/M735bHf9y3ApuV96OkvvkuAzwz5na1n+8QG8x33UgjGr0mS1ul7jHb2dyjwMklfkPRsAEnPKx2Gwx3gAM6jOohvB1xgeyHVQW174Fe9lSRNLB2cE4AJQ/KZQvVp8S5JG5U86Nt2Qtl2IrBW2XbSE3gfBvkBsIek1/TiqLrcc7rtG6iaiT4paW1JLwf2GG4npXP9DODrqjq/J0nqFbQ/AhtL2mCEHE4GXidpl/K6DqY6ozpvtORLp+nzykHzbuAR4NHRtrO9pOz/s+U1vxh4Z3k/2jAFuBe4u/TTfLhv2W+oXscB5e9mL6q/r55vAe+VtL0qkyW9bkhRiycohWD8mkd1gO09Dhu0su3rgJcCWwALJd0NnEp1MFw+wjb3ARcDC22vKLN/A9xg+9a+VT9ecjiE6tPfA2UeVJ2K61K16Z8PnDkkzNvK+t+gakp4gOpgsNrKAXEvqk7WZVSfNj/MY/8Xb6UqandQFajvD9jd26ja7K8GbqXq/MX21VQdnotLn8zjmjBK/8W+wFep3oM9qDp1VzC6WVRncvdSve9ft31Oje0A9qH6Xd9M1Tl7qO2f19x2dX0S2JaqeP0UOK23oLzuN1IVpruo3pufUBVHbC8A3k11QcGdwCJGuaopRqd8MU1EPJlJmg/8u+3vjHUu41XOCCLiSUXSTpKeXZqG3k51yezQM8VYg3LVUEQ82Tyfqu9kMrAYeHPvJsdoRpqGIiI6Lk1DEREd95RrGpo6daq32GKLsU4jIuIp5aKLLrrN9rThlj3lCsEWW2zBggULxjqNiIinFEk3jLQsTUMRER2XQhAR0XEpBBERHZdCEBHRcSkEEREdl0IQEdFxKQQRER2XQhAR0XEpBBERHfeUu7M4nrgjz7628Rgf3HWrxmNERDNyRhAR0XEpBBERHZdCEBHRcSkEEREdl0IQEdFxKQQRER2XQhAR0XEpBBERHZdCEBHRcSkEEREdl0IQEdFxKQQRER2XQhAR0XEpBBERHZdCEBHRcSkEEREdl0IQEdFxjRYCSbtJukbSIkmHDLN8M0nnSLpE0uWSXttkPhERsbLGCoGkCcBRwO7AbGAfSbOHrPZx4GTb2wB7A19vKp+IiBhek2cE2wGLbC+2vQI4CdhryDoGnl6ebwDc3GA+ERExjCYLwabAkr7ppWVev8OAfSUtBeYB7xtuR5L2l7RA0oJly5Y1kWtERGeNdWfxPsB3bU8HXgscJ2mlnGwfbXuO7TnTpk1rPcmIiPGsyUJwEzCjb3p6mdfvncDJALZ/A6wDTG0wp4iIGKLJQnAhMEvSTElrU3UGzx2yzo3ALgCStqYqBGn7iYhoUWOFwPbDwAHAWcBVVFcHLZR0uKQ9y2oHA++WdBlwIrCfbTeVU0RErGxikzu3PY+qE7h/3if6nl8J7NhkDhERMdhYdxZHRMQYSyGIiOi4FIKIiI5LIYiI6LgUgoiIjkshiIjouBSCiIiOSyGIiOi4FIKIiI5LIYiI6LgUgoiIjkshiIjouBSCiIiOSyGIiOi4FIKIiI5LIYiI6LgUgoiIjkshiIjouBSCiIiOSyGIiOi4FIKIiI4btRBIelEbiURExNioc0bwdUkXSPoHSRs0nlFERLRq1EJg+xXA/wZmABdJOkHSro1nFhERrajVR2D7d8DHgY8AOwH/JulqSW9sMrmIiGhenT6CF0s6ErgKeBWwh+2ty/MjG84vIiIaNrHGOl8FjgE+ZvuB3kzbN0v6eGOZRUREK+o0Df3I9nH9RUDSgQC2j2sss4iIaEWdQvB3w8zbbw3nERERY2TEpiFJ+wBvBWZKmtu3aApwR9OJRUREOwb1EZwH3AJMBb7UN385cHmTSUVERHtGLAS2bwBuAF7aXjoREdG2QU1D/9/2yyUtB9y/CLDtpzeeXURENG7QGcHLy88p7aUTERFtG3RGsNGgDW2nwzgiYhwY1Fl8EVWTkIZZZuC5jWQUERGtGtQ0NLPNRCIiYmyMeEOZpD8rP7cd7lFn55J2k3SNpEWSDhlhnbdIulLSQkknrNrLiIiIVTWoaeggYH8efw9Bj6kGnRuRpAnAUcCuwFLgQklzbV/Zt84s4KPAjrbvlPTMJ5h/RESspkFNQ/uXn3+5ivveDlhkezGApJOAvYAr+9Z5N3CU7TtLrFtXMVZERKyiOsNQryPpIEmnSTpV0gckrVNj35sCS/qml5Z5/bYCtpL0X5LOl7TbCDnsL2mBpAXLli2rEToiIuqqM+jc94EXUA1H/bXyfE2NOjoRmAXsDOwDfEvShkNXsn207Tm250ybNm0NhY6ICKj3fQQvtD27b/ocSVeOuPZjbqL6esue6WVev6XAfNsPAb+XdC1VYbiwxv4jImINqHNGcLGkHXoTkrYHFtTY7kJglqSZktYG9gbmDlnnx1RnA0iaStVUtLjGviMiYg0ZdGfxFVRXB00CzpN0Y5neHLh6tB3bfljSAcBZwATg27YXSjocWGB7bln2V+UM4xHgw7ZvX90XFRER9Q1qGnr96u7c9jxg3pB5n+h7bqrLVA9a3VgREbFqRhuG+k/KNf51rhaKiIinkDqXj+4p6XfA74FzgeuBMxrOKyIiWlKns/hTwA7AtWX8oV2A8xvNKiIiWlOnEDxUOnDXkrSW7XOAOQ3nFRERLalzH8FdktYHfg0cL+lW4L5m04qIiLbUOSPYC3gA+ABwJnAdsEeTSUVERHtGPSOwfZ+kZ1MNIncHcFau9Y+IGD/qXDX0LuAC4I3Am4HzJf1904lFREQ76vQRfBjYpncWIGlj4Dzg200mFhER7ajTR3A7sLxvenmZFxER48CgsYZ6wz4sAuZL+k+qsYb2Ai5vIbeIiGjBoKahKeXndeXR85/NpRMREW0bNNbQJ/uny70E2L636aQiIqI9da4aeqGkS4CFwEJJF0l6QfOpRUREG+p0Fh8NHGR7c9ubAwcD32o2rYiIaEudQjC5jC8EgO1fApMbyygiIlpV5z6CxZL+hce+sH5f8nWSERHjRp0zgr8HpgGnAacCU8u8iIgYBwaeEUiaAPyz7fe3lE9ERLRs4BmB7UeAl7eUS0REjIE6fQSXSJoL/JC+7yGwfVpjWUVERGvqFIJ1qMYWelXfPFP1GURExFNcrdFHbd/WeCYRETEmRuwjkLSHpGXA5ZKWSnpZi3lFRERLBnUWfwZ4he1NgDcBn20npYiIaNOgQvCw7asBbM/nsdFIIyJiHBnUR/DMvu8kWGna9pebSysiItoyqBB8i8efBQydjoiIcaD29xFERMT4VGesoYiIGMdSCCIiOi6FICKi40bsIxhyxdBKctVQRMT4MOiqod4VQs8HXgLMLdN7ABc0mVRERLRn1KuGJP0K2Nb28jJ9GPDTVrKLiIjG1ekjeBawom96RZkXERHjQJ3RR78PXCDpR2X6r4HvNZdSRES0adRCYPszks4AXlFmvcP2Jc2mFRERbal7+eh6wD22vwIslTSzzkaSdpN0jaRFkg4ZsN6bJFnSnJr5RETEGjJqIZB0KPAR4KNl1iTgBzW2mwAcBewOzAb2kTR7mPWmAAcC8+unHRERa0qdM4I3AHtSvq/Y9s3UG3xuO2CR7cW2VwAnAXsNs96ngM8B/10r44iIWKPqFIIVtk31PcVImlxz35sCS/qml5Z5fyJpW2CG7YGXo0raX9ICSQuWLVtWM3xERNRRpxCcLOmbwIaS3g38HDhmdQNLWgv4MnDwaOvaPtr2HNtzpk2btrqhIyKiT52rhr4oaVfgHqq7jD9h++wa+74JmNE3Pb3M65kCvBD4pSSAZwNzJe1pe0HN/CMiYjWNWggkfc72R4Czh5k3yIXArHKF0U3A3sBbewtt3w1M7dvnL4EPpQhERLSrTtPQrsPM2320jWw/DBwAnAVcBZxse6GkwyXt+cTSjIiIpgwaffT/AP8AbCnp8r5FU4Dz6uzc9jxg3pB5nxhh3Z3r7DMiItasQU1DJwBnAJ8F+m8GW277jkazioiI1ozYNGT7btvXA18B7rB9g+0bgIclbd9WghER0aw6fQTfAO7tm763zIuIiHGgTiFQuaEMANuPUm/U0oiIeAqoUwgWS3q/pEnlcSCwuOnEIiKiHXUKwXuBl1HdC7AU2B7Yv8mkIiKiPXXuLL6V6mawiIgYhwbdR/BPtj8v6auUAef62X5/o5lFREQrBp0RXFV+ZsiHiIhxbMRCYPv08jPfTxwRMY4Naho6nWGahHpsZ7ygiIhxYFDT0BfLzzdSDRHd+3rKfYA/NplURES0Z1DT0LkAkr5ku/9L5U+XlH6DiIhxos59BJMlPbc3Ub5foO7XVUZExJNcnaEiPkj1LWKLAQGbA+9pNKuIiGhNnRvKzpQ0C/izMutq2w82m1ZERLRl1KYhSesBHwYOsH0ZsJmk1zeeWUREtKJOH8F3gBXAS8v0TcCnG8soIiJaVacQbGn788BDALbvp+oriIiIcaBOIVghaV3KzWWStgTSRxARMU7UuWroUOBMYIak44Edgf2aTCoiItozsBBIEnA11d3FO1A1CR1o+7YWcouIiBYMLAS2LWme7RcBP20pp4iIaFGdPoKLJb2k8UwiImJM1Okj2B7YV9L1wH1UzUO2/eImE4uIiHbUKQSvaTyLiIgYM4O+j2Adqi+ufx5wBXCs7YfbSiwiItoxqI/ge8AcqiKwO/ClVjKKiIhWDWoaml2uFkLSscAF7aQUERFtGnRG8FDvSZqEIiLGr0FnBP9L0j3luYB1y3TvqqGnN55dREQ0btBXVU5oM5GIiBgbdW4oi4iIcSyFICKi41IIIiI6LoUgIqLjUggiIjqu0UIgaTdJ10haJOmQYZYfJOlKSZdL+n+SNm8yn4iIWFljhUDSBOAoquEpZgP7SJo9ZLVLgDllJNNTgM83lU9ERAyvyTOC7YBFthfbXgGcBOzVv4Ltc2zfXybPB6Y3mE9ERAyjyUKwKbCkb3ppmTeSdwJnDLdA0v6SFkhasGzZsjWYYkREPCk6iyXtSzXS6ReGW277aNtzbM+ZNm1au8lFRIxzdb6YZlXdBMzom55e5j2OpFcD/wzsZPvBBvOJiIhhNHlGcCEwS9JMSWsDewNz+1eQtA3wTWBP27c2mEtERIygsUJQhq4+ADgLuAo42fZCSYdL2rOs9gVgfeCHki6VNHeE3UVEREOabBrC9jxg3pB5n+h7/uom40dExOieFJ3FERExdlIIIiI6LoUgIqLjUggiIjouhSAiouNSCCIiOi6FICKi41IIIiI6LoUgIqLjUggiIjouhSAiouNSCCIiOi6FICKi41IIIiI6LoUgIqLjUggiIjqu0S+miYho05FnX9t4jA/uulXjMdqWM4KIiI5LIYiI6LgUgoiIjkshiIjouBSCiIiOSyGIiOi4FIKIiI5LIYiI6LgUgoiIjkshiIjouBSCiIiOSyGIiOi4FIKIiI5LIYiI6LgUgoiIjkshiIjouBSCiIiOSyGIiOi4FIKIiI5LIYiI6LgUgoiIjmu0EEjaTdI1khZJOmSY5U+T9B9l+XxJWzSZT0RErKyxQiBpAnAUsDswG9hH0uwhq70TuNP284Ajgc81lU9ERAxvYoP73g5YZHsxgKSTgL2AK/vW2Qs4rDw/BfiaJNl2Ewkdefa1Tez2cT6461aNx4iIWJOaLASbAkv6ppcC24+0ju2HJd0NbAzc1r+SpP2B/cvkvZKuaSTj4U0dms8gB41h7DUsrzuxE3sYa/Bvve3XvflIC5osBGuM7aOBo8citqQFtuckdmIndmKPl9hDNdlZfBMwo296epk37DqSJgIbALc3mFNERAzRZCG4EJglaaaktYG9gblD1pkLvL08fzPwi6b6ByIiYniNNQ2VNv8DgLOACcC3bS+UdDiwwPZc4FjgOEmLgDuoisWTzZg0SSV2Yid2YrdF+QAeEdFtubM4IqLjUggiIjouhWAEow2P0XDsb0u6VdJvW447Q9I5kq6UtFDSgS3GXkfSBZIuK7E/2VbsvhwmSLpE0k/GIPb1kq6QdKmkBS3H3lDSKZKulnSVpJe2FPf55fX2HvdI+kAbsUv8D5a/td9KOlHSOi3GPrDEXdjmax6R7TyGPKg6t68DngusDVwGzG4x/iuBbYHftvy6nwNsW55PAa5t63UDAtYvzycB84EdWn79BwEnAD9pM26JfT0wte24Jfb3gHeV52sDG45BDhOAPwCbtxRvU+D3wLpl+mRgv5ZivxD4LbAe1QU7PweeNxa/+94jZwTD+9PwGLZXAL3hMVph+1dUV1G1yvYtti8uz5cDV1H9w7QR27bvLZOTyqO1KxkkTQdeBxzTVswnA0kbUH3wOBbA9grbd41BKrsA19m+ocWYE4F1yz1M6wE3txR3a2C+7fttPwycC7yxpdjDSiEY3nDDY7RyQHyyKCPBbkP1ybytmBMkXQrcCpxtu7XYwL8C/wQ82mLMfgZ+JumiMqRKW2YCy4DvlGaxYyRNbjF+z97AiW0Fs30T8EXgRuAW4G7bP2sp/G+BV0jaWNJ6wGt5/M23rUshiJVIWh84FfiA7Xvaimv7Edt/TnUX+naSXthGXEmvB261fVEb8UbwctvbUo3W+4+SXtlS3IlUzZDfsL0NcB/Qdp/Y2sCewA9bjPkMqrP8mcAmwGRJ+7YR2/ZVVCMt/ww4E7gUeKSN2CNJIRheneExxiVJk6iKwPG2TxuLHErTxDnAbi2F3BHYU9L1VM2Ar5L0g5ZiA3/6hIrtW4EfUTVPtmEpsLTv7OsUqsLQpt2Bi23/scWYrwZ+b3uZ7YeA04CXtRXc9rG2/8L2K4E7qfrjxkwKwfDqDI8x7kgSVVvxVba/3HLsaZI2LM/XBXYFrm4jtu2P2p5uewuq3/UvbLfy6RBA0mRJU3rPgb+iaj5onO0/AEskPb/M2oXHDxXfhn1osVmouBHYQdJ65e9+F6o+sVZIemb5uRlV/8AJbcUezlNi9NG2eYThMdqKL+lEYGdgqqSlwKG2j20h9I7A24ArSls9wMdsz2sh9nOA75UvNFoLONl265dxjpFnAT+qjkdMBE6wfWaL8d8HHF8+9CwG3tFW4FL4dgXe01ZMANvzJZ0CXAw8DFxCu0M+nCppY+Ah4B/HqIP+TzLEREREx6VpKCKi41IIIiI6LoUgIqLjUggiIjouhSAiouNSCCIiOi6FICKi4/4HEHMv4f97kiwAAAAASUVORK5CYII=\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "objects = [str(x) for x in range(10)]\n", + "y_pos = np.arange(len(objects))\n", + "plt.bar(y_pos, probabilities, align='center', alpha=0.5)\n", + "plt.xticks(y_pos, objects)\n", + "plt.ylabel('Predicted Probability')\n", + "plt.title('LFC-w1a1 Predictions for Image')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Call Brevitas FINN-ONNX export and visualize with Netron\n", + "\n", + "Brevitas comes with built-in FINN-ONNX export functionality. This is similar to the regular ONNX export capabilities of PyTorch, with a few differences:\n", + "\n", + "1. The weight quantization logic is not exported as part of the graph; rather, the quantized weights themselves are exported.\n", + "2. Special quantization annotations are used to preserve the low-bit quantization information. ONNX (at the time of writing) supports 8-bit quantization as the minimum bitwidth, whereas FINN-ONNX quantization annotations can go down to binary/bipolar quantization.\n", + "3. Low-bit quantized activation functions are exported as MultiThreshold operators.\n", + "\n", + "It's actually quite straightforward to export ONNX from our Brevitas model as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", "output_type": "stream", "text": [ - "Requirement already satisfied: netron in /opt/conda/lib/python3.6/site-packages (3.5.9)\r\n" + "/workspace/brevitas_cnv_lfc/training_scripts/models/LFC.py:73: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + " x = 2.0 * x - torch.tensor([1.0])\n" ] } ], "source": [ - "!pip install --user netron" + "import brevitas.onnx as bo\n", + "export_onnx_path = \"/tmp/LFCW1A1.onnx\"\n", + "input_shape = (1, 1, 28, 28)\n", + "bo.export_finn_onnx(lfc, input_shape, export_onnx_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's examine what the exported ONNX model looks like. For this, we will use the Netron visualizer:" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Serving 'LFCW1A1.onnx' at http://0.0.0.0:8081\n" + "Serving '/tmp/LFCW1A1.onnx' at http://0.0.0.0:8081\n" ] } ], "source": [ "import netron\n", - "netron.start('LFCW1A1.onnx', port=8081, host=\"0.0.0.0\")" + "netron.start(export_onnx_path, port=8081, host=\"0.0.0.0\")" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>\n" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%html\n", + "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When running this notebook in the FINN Docker container, you should be able to see an interactive visualization of the imported network above, and click on individual nodes to inspect their parameters. If you look at any of the MatMul nodes, you should be able to see that the weights are all {-1, +1} values, and the activations are Sign functions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Import into FINN and call cleanup transformations\n", + "\n", + "We will now import this ONNX model into FINN using the ModelWrapper, and examine some of the graph attributes from Python." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "input: \"32\"\n", + "input: \"33\"\n", + "output: \"35\"\n", + "op_type: \"MatMul\"" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from finn.core.modelwrapper import ModelWrapper\n", + "model = ModelWrapper(export_onnx_path)\n", + "model.graph.node[9]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The ModelWrapper exposes a range of other useful functions as well. For instance, by convention the second input of the MatMul node will be a pre-initialized weight tensor, which we can view using the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 1., 1., 1., ..., 1., 1., -1.],\n", + " [ 1., 1., -1., ..., 1., 1., -1.],\n", + " [-1., 1., -1., ..., -1., 1., -1.],\n", + " ...,\n", + " [-1., 1., -1., ..., -1., -1., 1.],\n", + " [ 1., 1., -1., ..., 1., 1., -1.],\n", + " [-1., 1., 1., ..., -1., -1., 1.]], dtype=float32)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.get_initializer(model.graph.node[9].input[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we want to operate further on this model in FINN, it is a good idea to execute certain \"cleanup\" transformations on this graph. Here, we will run shape inference and constant folding on this graph, and visualize the resulting graph in Netron again." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Stopping http://0.0.0.0:8081\n", + "Serving '/tmp/LFCW1A1-clean.onnx' at http://0.0.0.0:8081\n" + ] + } + ], + "source": [ + "from finn.transformation.fold_constants import FoldConstants\n", + "from finn.transformation.infer_shapes import InferShapes\n", + "model = model.transform(InferShapes())\n", + "model = model.transform(FoldConstants())\n", + "export_onnx_path_transformed = \"/tmp/LFCW1A1-clean.onnx\"\n", + "model.save(export_onnx_path_transformed)\n", + "netron.start(export_onnx_path_transformed, port=8081, host=\"0.0.0.0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -135,6 +617,67 @@ "<iframe src=\"http://0.0.0.0:8081/\" style=\"position: relative; width: 100%;\" height=\"400\"></iframe>" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the resulting graph has become smaller and simpler. Specifically, the input reshaping is now a single Reshape node instead of the Shape -> Gather -> Unsqueeze -> Concat -> Reshape sequence. We can now use the internal ONNX execution capabilities of FINN to ensure that we still get the same output from this model as we did with PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 3.3252678 , -2.5652065 , 9.215742 , -1.4251148 , 1.4251148 ,\n", + " -3.3727715 , 0.28502294, -0.5700459 , 7.07807 , -1.2826033 ]],\n", + " dtype=float32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import finn.core.onnx_exec as oxe\n", + "input_dict = {\"0\": nph.to_array(input_tensor)}\n", + "output_dict = oxe.execute_onnx(model, input_dict)\n", + "produced_finn = output_dict[list(output_dict.keys())[0]]\n", + "\n", + "produced_finn" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.isclose(produced, produced_finn).all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have succesfully verified that the transformed and cleaned-up FINN graph still produces the same output, and can now use this model for further processing in FINN." + ] + }, { "cell_type": "code", "execution_count": null,