{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FINN - End-to-End Flow\n",
    "-----------------------------------------------------------------\n",
    "\n",
    "In this notebook, we will show how to take a simple, binarized, fully-connected network trained on the MNIST data set and take it all the way down to a customized bitfile running on a PYNQ board. \n",
    "\n",
    "This notebook is quite lengthy, and some of the cells (involving Vivado synthesis) may take up to an hour to finish running. To let you save and resume your progress, we will save the intermediate ONNX models that are generated in the various steps to disk, so that you can jump back directly to where you left off.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "The FINN compiler comes with many *transformations* that modify the ONNX representation of the network according to certain patterns. This notebook will demonstrate a *possible* sequence of such transformations to take a particular trained network all the way down to hardware, as shown in the figure below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](finn-design-flow-example.svg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cylinder-like fields show the state of the network representation in the respective step. The rectangular fields represent the transformations that are applied to the network to achieve a certain result. The diagram is divided into 5 blocks, each of it includes several flow steps. The flow starts in top left corner with Brevitas export (purple block), followed by the preparation of the network (grey block) for the Vivado HLS synthesis and Vivado IPI stitching (yellow block), and finally building a PYNQ overlay bitfile and testing it on a PYNQ board (pink block).\n",
    "There is an additional section for functional verification (green block), which we will not cover in this notebook.\n",
    "\n",
    "\n",
    "This Jupyter notebook is organized based on the sections described above. We will use the following helper functions, `showSrc` to show source code of FINN library calls and `showInNetron` to show the ONNX model at the current transformation step. The Netron displays are interactive, but they only work when running the notebook actively and not on GitHub (i.e. if you are viewing this on GitHub you'll only see blank squares)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import netron\n",
    "from finn.util.basic import make_build_dir\n",
    "from IPython.display import IFrame\n",
    "\n",
    "def showSrc(what):\n",
    "    print(\"\".join(inspect.getsourcelines(what)[0]))\n",
    "    \n",
    "def showInNetron(model_filename):\n",
    "    netron.start(model_filename, port=8081, host=\"0.0.0.0\")\n",
    "    return IFrame(src=\"http://0.0.0.0:8081/\", width=\"100%\", height=400)\n",
    "    \n",
    "build_dir = \"/workspace/finn\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline\n",
    "-------------\n",
    "1. [Brevitas export](#brev_exp)\n",
    "2. [Network preparation](#nw_prep)\n",
    "    * Basic transformations\n",
    "    * Streamlining\n",
    "    * Conversion to HLS layers\n",
    "    * Folding\n",
    "3. [Vivado HLS and Vivado IPI](#vivado)\n",
    "    * HLS IP per layer\n",
    "    * Creation of stitched design\n",
    "4. [Synthesize, Deploy and Test on PYNQ](#hw_test)\n",
    "    * PYNQ shell project\n",
    "    * Synthesis, place and route\n",
    "    * Driver generation\n",
    "    * Deployment and remote execution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Brevitas export <a id='brev_exp'></a>\n",
    "FINN expects an ONNX model as input. This can be a model trained with [Brevitas](https://github.com/Xilinx/brevitas). Brevitas is a PyTorch library for quantization-aware training and the FINN Docker image comes with several [example Brevitas networks](https://github.com/maltanar/brevitas_cnv_lfc). To show the FINN end-to-end flow, we'll use the TFC-w1a1 model as example network.\n",
    "\n",
    "First a few things have to be imported. Then the model can be loaded with the pretrained weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/brevitas_cnv_lfc/training_scripts/models/TFC.py:73: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
      "  x = 2.0 * x - torch.tensor([1.0])\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "from finn.util.test import get_test_model_trained\n",
    "import brevitas.onnx as bo\n",
    "\n",
    "tfc = get_test_model_trained(\"TFC\", 1, 1)\n",
    "bo.export_finn_onnx(tfc, (1, 1, 28, 28), build_dir+\"/tfc_w1_a1.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model was now exported, loaded with the pretrained weights and saved under the name \"lfc_w1_a1.onnx\".\n",
    "To visualize the exported model, Netron can be used. Netron is a visualizer for neural networks and allows interactive investigation of network properties. For example, you can click on the individual nodes and view the properties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc62d60e5c0>"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "showInNetron(build_dir+\"/tfc_w1_a1.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the model in .onnx format, we can work with it using FINN. For that FINN `ModelWrapper` is used. It is a wrapper around the ONNX model which provides several helper functions to make it easier to work with the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.core.modelwrapper import ModelWrapper\n",
    "model = ModelWrapper(build_dir+\"/tfc_w1_a1.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the model is prepared and could be simulated using Python. How this works is described in subsection [Simulation using Python](#simpy) in the section about *Simulation & Emulation flows for functional verification*.\n",
    "\n",
    "The model can now also be processed in different ways. The principle of FINN are analysis and transformation passes, which can be applied to the model. An analysis pass extracts specific information about the model and returns it to the user in the form of a dictionary. A transformation pass changes the model and returns the changed model back to the FINN flow.\n",
    "\n",
    "Since the goal in this notebook is to process the model to such an extent that a bitstream can be generated from it, the focus is on the transformations that are necessary for this. In the next section these are discussed in more detail."
   ]
  },
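  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the analysis-pass concept, the sketch below defines a small custom analysis function (our own example, not a built-in FINN pass) that counts how often each op_type occurs in the graph, and runs it through `ModelWrapper`'s `analysis` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal sketch of an analysis pass: a function that takes the model\n",
    "# and returns a dictionary (count_op_types is our own example here,\n",
    "# not a built-in FINN analysis pass)\n",
    "def count_op_types(model):\n",
    "    counts = {}\n",
    "    for node in model.graph.node:\n",
    "        counts[node.op_type] = counts.get(node.op_type, 0) + 1\n",
    "    return counts\n",
    "\n",
    "model.analysis(count_op_types)"
   ]
  },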
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Network preparation <a id='nw_prep'></a>\n",
    "\n",
    "* [Tidy-up transformations](#basic_trafo)\n",
    "* [Streamlining](#streamline)\n",
    "* [Conversion to HLS layers](#hls_layers)\n",
    "* [Folding](#folding)\n",
    "\n",
    "\n",
    "In this section, we will put the network through a series of transformations that puts it in a form that can be stitched together to form a FINN-style dataflow architecture, yielding a high-performance, high-efficiency FPGA accelerator."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FINN-style Dataflow Architectures\n",
    "\n",
    "We start with a quick recap of FINN-style dataflow architectures. The key idea in such architectures is to parallelize across layers as well as within layers by dedicating a proportionate amount of compute resources to each layer, as illustrated in the figure below taken from the [FINN-R paper](https://arxiv.org/pdf/1809.04570.pdf):\n",
    "\n",
    "![](finn-hw-arch.png)\n",
    "\n",
    "In practice, the compute arrays are instantiated by function calls to optimized Vivado HLS building blocks from the [finn-hlslib](https://github.com/Xilinx/finn-hlslib) library. As these function calls can only handle certain patterns/cases, we need to transform the network into an appropriate form so that we can replace network layers with these function calls, which is the goal of the network preparation process."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tidy-up transformations <a id='basic_trafo'></a>\n",
    "This section deals with some basic transformations, which are applied to the model like a kind of \"tidy-up\" to make it easier to be processed. They do not appear in the diagram above, but they are applied in many steps in the FINN flow to postprocess the model after a transformation and/or to prepare it for the next transformation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These transformations are:\n",
    "* GiveUniqueNodeNames\n",
    "* GiveReadableTensorNames\n",
    "* InferShapes\n",
    "* InferDataTypes\n",
    "* FoldConstants"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first two transformations (`GiveUniqueNodeNames`, `GiveReadableTensorNames`) the nodes in the graph are first given unique (by enumeration) names, then the tensors are given human-readable names (based on the node names). The following two transformations (`InferShapes`, `InferDataTypes`) derive the shapes and data types of the tensors from the model properties and set them in the `ValueInfo` of the model. These transformations can almost always be applied without negative effects and do not affect the structure of the graph, ensuring that all the information needed is available.\n",
    "\n",
    "The last listed transformation is `FoldConstants`, which performs constant folding. It identifies a node with constant inputs and determines its output. The result is then set as constant-only inputs for the following node and the old node is removed. Although this transformation changes the structure of the model, it is a transformation that is usually always desired and can be applied to any model."
   ]
  },
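  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the idea behind constant folding, here is a conceptual sketch (in plain numpy, not the actual `FoldConstants` implementation): a node whose inputs are all constants can be evaluated once at transformation time, and its result wired into the consumer as a constant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# conceptual sketch of constant folding, not the FoldConstants source:\n",
    "w = np.ones((4, 4), dtype=np.float32)  # constant initializer\n",
    "scale = np.float32(0.5)                # another constant\n",
    "# suppose the graph contains Mul(w, scale) feeding a MatMul node;\n",
    "# the Mul has only constant inputs, so it can be evaluated offline...\n",
    "folded_w = w * scale\n",
    "# ...and the MatMul then reads folded_w directly, so the Mul node\n",
    "# (and its constant inputs) can be removed from the graph\n",
    "print(folded_w[0])"
   ]
  },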
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These transformations can be imported and applied as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames\n",
    "from finn.transformation.infer_shapes import InferShapes\n",
    "from finn.transformation.infer_datatypes import InferDataTypes\n",
    "from finn.transformation.fold_constants import FoldConstants\n",
    "\n",
    "model = model.transform(InferShapes())\n",
    "model = model.transform(FoldConstants())\n",
    "model = model.transform(GiveUniqueNodeNames())\n",
    "model = model.transform(GiveReadableTensorNames())\n",
    "model = model.transform(InferDataTypes())\n",
    "\n",
    "model.save(build_dir+\"/tfc_w1_a1_tidy.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result of these transformations can be viewed with netron after the model has been saved again. By clicking on the individual nodes, it can now be seen, for example, that each node has been given a name. Also the whole upper area could be folded, so that now the first node is \"Reshape\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1_tidy.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc6c4430828>"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "showInNetron(build_dir+\"/tfc_w1_a1_tidy.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Streamlining <a id='streamline'></a>\n",
    "Streamlining is a transformation containing several sub-transformations. The goal of streamlining is to eliminate floating point operations by moving them around, then collapsing them into one operation and in the last step transform them into multi-thresholding nodes. For more information on the theoretical background of this, see [this paper](https://arxiv.org/pdf/1709.04060).\n",
    "\n",
    "Let's have a look at which sub-transformations `Streamline` consists of:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class Streamline(Transformation):\n",
      "    \"\"\"Apply the streamlining transform, see arXiv:1709.04060.\"\"\"\n",
      "\n",
      "    def apply(self, model):\n",
      "        streamline_transformations = [\n",
      "            ConvertSubToAdd(),\n",
      "            BatchNormToAffine(),\n",
      "            ConvertSignToThres(),\n",
      "            MoveScalarAddPastMatMul(),\n",
      "            MoveScalarMulPastMatMul(),\n",
      "            MoveAddPastMul(),\n",
      "            CollapseRepeatedAdd(),\n",
      "            CollapseRepeatedMul(),\n",
      "            AbsorbAddIntoMultiThreshold(),\n",
      "            FactorOutMulSignMagnitude(),\n",
      "            AbsorbMulIntoMultiThreshold(),\n",
      "            Absorb1BitMulIntoMatMul(),\n",
      "            RoundAndClipThresholds(),\n",
      "        ]\n",
      "        for trn in streamline_transformations:\n",
      "            model = model.transform(trn)\n",
      "            model = model.transform(GiveUniqueNodeNames())\n",
      "            model = model.transform(GiveReadableTensorNames())\n",
      "            model = model.transform(InferDataTypes())\n",
      "        return (model, False)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from finn.transformation.streamline import Streamline\n",
    "showSrc(Streamline)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, several transformations are involved in the streamlining transformation. There are move and collapse transformations. In the last step the operations are transformed into multithresholds. The involved transformations can be viewed in detail [here](https://github.com/Xilinx/finn/tree/dev/src/finn/transformation/streamline). After each transformation, three of the tidy-up transformations (`GiveUniqueNodeNames`, `GiveReadableTensorNames` and `InferDataTypes`) are applied to the model.\n",
    "\n",
    "After streamlining the network looks as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1_streamlined.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc65422bf60>"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = ModelWrapper(build_dir+\"/tfc_w1_a1_tidy.onnx\")\n",
    "model = model.transform(Streamline())\n",
    "model.save(build_dir+\"/tfc_w1_a1_streamlined.onnx\")\n",
    "showInNetron(build_dir+\"/tfc_w1_a1_streamlined.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that the network has become simplified considerably compared to the previous step -- a lot of nodes have disappeared between the `MatMul` layers, and the `Sign` nodes have been replaced with `MultiThreshold` nodes instead. \n",
    "\n",
    "**The current implementation of streamlining is highly network-specific and may not work for your network if its topology is very different than the example network here. We hope to rectify this in future releases.**\n",
    "\n",
    "Our example network is a quantized network with 1-bit bipolar (-1, +1 values) precision, and we want FINN to implement them as XNOR-popcount operations [as described in the original FINN paper](https://arxiv.org/pdf/1612.07119). For this reason, after streamlining, the resulting bipolar matrix multiplications are converted into xnorpopcount operations. This transformation produces operations that are again collapsed and converted into thresholds. This procedure is shown below. "
   ]
  },
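  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this conversion is exact, here is a small numerical sanity check (illustrative only): for vectors with entries in {-1, +1}, the dot product equals `2 * popcount(XNOR(x, w)) - N` once the values are re-encoded as {0, 1}."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# random bipolar vectors in {-1, +1}\n",
    "n = 8\n",
    "x = np.random.choice([-1, 1], n)\n",
    "w = np.random.choice([-1, 1], n)\n",
    "\n",
    "# re-encode {-1, +1} as binary {0, 1}\n",
    "xb = (x + 1) // 2\n",
    "wb = (w + 1) // 2\n",
    "\n",
    "# XNOR is 1 exactly where the bipolar product is +1\n",
    "xnor = 1 - np.logical_xor(xb, wb).astype(int)\n",
    "popcount = xnor.sum()\n",
    "\n",
    "# (#agreements) - (#disagreements) = 2 * popcount - n\n",
    "assert x @ w == 2 * popcount - n"
   ]
  },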
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1a1_ready_for_hls_conversion.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc64d8d9f98>"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount\n",
    "import finn.transformation.streamline.absorb as absorb\n",
    "from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds\n",
    "\n",
    "model = model.transform(ConvertBipolarMatMulToXnorPopcount())\n",
    "model = model.transform(absorb.AbsorbAddIntoMultiThreshold())\n",
    "model = model.transform(absorb.AbsorbMulIntoMultiThreshold())\n",
    "model = model.transform(RoundAndClipThresholds())\n",
    "\n",
    "model.save(build_dir+\"/tfc_w1a1_ready_for_hls_conversion.onnx\")\n",
    "showInNetron(build_dir+\"/tfc_w1a1_ready_for_hls_conversion.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Observe the pairs of `XnorPopcountmatMul` and `MultiThreshold` layers following each other -- this is the particular pattern that the next step will be looking for in order to convert them to HLS layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conversion to HLS layers <a id='hls_layers'></a>\n",
    "Converts the nodes to HLS layers that correspond to the functions in [finn-hls library](https://finn-hlslib.readthedocs.io/en/latest/). In our case this transformation onverts pairs of binary XnorPopcountMatMul layers to StreamingFCLayer_Batch layers. Any immediately following MultiThreshold layers will also be absorbed into the MVTU.\n",
    "\n",
    "Below is the code for the transformation and the network is visualized using netron to create the new structure with `StreamingFCLayer_Batch` nodes, which will correspond to a function call from the [finn-hlslib](https://finn-hlslib.readthedocs.io/en/latest/library/fclayer.html#_CPPv4I_j_j_j_j000_i_i000E22StreamingFCLayer_BatchvRN3hls6streamI7ap_uintI9InStreamWEEERN3hls6streamI7ap_uintI10OutStreamWEEERK2TWRK2TAKjRK1R) library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1_hls_layers.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc62d60eb70>"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls\n",
    "model = ModelWrapper(build_dir+\"/tfc_w1a1_ready_for_hls_conversion.onnx\")\n",
    "model = model.transform(to_hls.InferBinaryStreamingFCLayer())\n",
    "model.save(build_dir+\"/tfc_w1_a1_hls_layers.onnx\")\n",
    "showInNetron(build_dir+\"/tfc_w1_a1_hls_layers.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each StreamingFCLayer_Batch node has two attributes that specify the degree of folding, PE and SIMD. In all nodes the values for these attributes are set as default to 1, which would correspond to a maximum folding (time multiplexing) and thus minimum performance. We will shortly cover how these can be adjusted, but first we want to separate the HLS layers from the non-HLS layers in this network."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a Dataflow Partition <a id='dataflow_partition'></a>\n",
    "\n",
    "In the graph above, you can see that there is a mixture of FINN HLS layers (StreamingFCLayer_Batch) with regular ONNX layers (Reshape, Mul, Add). To create a bitstream, FINN needs a model with only HLS layers. In order to achieve this, we will use the `CreateDataflowPartition` transformation to create a \"dataflow partition\" in this graph, separating out the HLS layers into another model, and replacing them with a placeholder layer called StreamingDataflowPartition:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1_dataflow_parent.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc62d60e1d0>"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from finn.transformation.fpgadataflow.create_dataflow_partition import CreateDataflowPartition\n",
    "\n",
    "model = ModelWrapper(build_dir+\"/tfc_w1_a1_hls_layers.onnx\")\n",
    "parent_model = model.transform(CreateDataflowPartition())\n",
    "parent_model.save(build_dir+\"/tfc_w1_a1_dataflow_parent.onnx\")\n",
    "showInNetron(build_dir+\"/tfc_w1_a1_dataflow_parent.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the StreamingFCLayer instances have all been replaced with a single `StreamingDataflowPartition`, which has an attribute `model` that points to the extracted, HLS dataflow-only graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/tmp/finn_maltanar/dataflow_partition_l2y9b77c/df_model.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc62d60e320>"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from finn.custom_op.registry import getCustomOp\n",
    "sdp_node = getCustomOp(parent_model.graph.node[2])\n",
    "dataflow_model_filename = sdp_node.get_nodeattr(\"model\")\n",
    "showInNetron(dataflow_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see all the extracted `StreamingFCLayer` instances have been moved to the child (dataflow) model. We will load the child model with `ModelWrapper` and continue working on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ModelWrapper(dataflow_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Folding and TLastMarker Insertion <a id='folding'></a>\n",
    "\n",
    "*Folding* in FINN describes how much a layer is time-multiplexed in terms of execution resources. There are several *folding factors* for each layer, controlled by the PE (parallelization over outputs) and SIMD (parallelization over inputs) parameters as described by the original [FINN paper](https://arxiv.org/pdf/1612.07119). The higher the PE and SIMD values are set, the faster the generated accelerator will run, and the more FPGA resources it will consume. \n",
    "\n",
    "Since the folding parameters are node attributes, they can be easily accessed and changed using a helper function of the `ModelWrapper`. But first we have to extract the nodes which are StreamingFCLayer_Batch operations. This is where the Netron visualization helps us, in the above diagram we can see that the first four nodes are StreamingFCLayer_Batch. Through the `print`s we can check if the extracted nodes all have the op_type \"StreamingFCLayer_Batch\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fc0 has the op_type: StreamingFCLayer_Batch\n",
      "fc1 has the op_type: StreamingFCLayer_Batch\n",
      "fc2 has the op_type: StreamingFCLayer_Batch\n",
      "fc3 has the op_type: StreamingFCLayer_Batch\n"
     ]
    }
   ],
   "source": [
    "fc0 = model.graph.node[0]\n",
    "fc1 = model.graph.node[1]\n",
    "fc2 = model.graph.node[2]\n",
    "fc3 = model.graph.node[3]\n",
    "print(\"fc0 has the op_type: \" + str(fc0.op_type))\n",
    "print(\"fc1 has the op_type: \" + str(fc1.op_type))\n",
    "print(\"fc2 has the op_type: \" + str(fc2.op_type))\n",
    "print(\"fc3 has the op_type: \" + str(fc3.op_type))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the higher-level [HLSCustomOp](https://github.com/Xilinx/finn/blob/dev/src/finn/custom_op/fpgadataflow/__init__.py) wrappers for these nodes. These wrappers provide easy access to specific properties of these nodes, such as the folding factors (PE and SIMD). Let's have a look at which node attributes are defined by the CustomOp wrapper, and adjust the SIMD and PE attributes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CustomOp wrapper is of class StreamingFCLayer_Batch\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'PE': ('i', True, 0),\n",
       " 'SIMD': ('i', True, 0),\n",
       " 'MW': ('i', True, 0),\n",
       " 'MH': ('i', True, 0),\n",
       " 'resType': ('s', True, ''),\n",
       " 'ActVal': ('i', False, 0),\n",
       " 'inputDataType': ('s', True, ''),\n",
       " 'weightDataType': ('s', True, ''),\n",
       " 'outputDataType': ('s', True, ''),\n",
       " 'binaryXnorMode': ('i', False, 0),\n",
       " 'noActivation': ('i', False, 0),\n",
       " 'inFIFODepth': ('i', False, 0),\n",
       " 'outFIFODepth': ('i', False, 0),\n",
       " 'backend': ('s', True, 'fpgadataflow'),\n",
       " 'code_gen_dir_npysim': ('s', False, ''),\n",
       " 'code_gen_dir_ipgen': ('s', False, ''),\n",
       " 'executable_path': ('s', False, ''),\n",
       " 'ipgen_path': ('s', False, ''),\n",
       " 'exec_mode': ('s', False, ''),\n",
       " 'sim_cycles': ('i', False, 0),\n",
       " 'rtlsim_trace': ('s', False, '')}"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fc0w = getCustomOp(fc0)\n",
    "fc1w = getCustomOp(fc1)\n",
    "fc2w = getCustomOp(fc2)\n",
    "fc3w = getCustomOp(fc3)\n",
    "\n",
    "print(\"CustomOp wrapper is of class \" + fc0w.__class__.__name__)\n",
    "\n",
    "fc0w.get_nodeattr_types()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the PE and SIMD are listed as node attributes, as well as the depths of the FIFOs that will be inserted between consecutive layers, and all can be adjusted using `set_nodeattr` subject to certain constraints.\n",
    "**In this notebook we are setting the folding factors and FIFO depths manually, but in a future version we will support determining the folding factors given an FPGA resource budget according to the analytical model from the [FINN-R paper](https://arxiv.org/pdf/1809.04570).**"
   ]
  },
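  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what these folding factors mean, below is a back-of-the-envelope sketch for the first layer (assuming a 784-input, 64-output layer and the cycle model from the [FINN-R paper](https://arxiv.org/pdf/1809.04570), where a layer needs roughly `(MW/SIMD) * (MH/PE)` cycles per input):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# back-of-the-envelope folding arithmetic, a sketch assuming the\n",
    "# cycle model from the FINN-R paper: cycles ~= (MW/SIMD) * (MH/PE)\n",
    "MW, MH = 784, 64     # assumed dimensions of the first TFC layer\n",
    "SIMD, PE = 16, 16    # the folding factors we set below\n",
    "\n",
    "fold_in = MW // SIMD   # 49 input groups processed sequentially\n",
    "fold_out = MH // PE    # 4 output groups processed sequentially\n",
    "print(\"approx. cycles per input for this layer:\", fold_in * fold_out)"
   ]
  },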
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SIMD controls the folding over the input vector\n",
    "# PE controls the folding over the output vector\n",
    "\n",
    "fc0w.set_nodeattr(\"inFIFODepth\", 50)\n",
    "fc0w.set_nodeattr(\"SIMD\", 16)\n",
    "fc0w.set_nodeattr(\"PE\", 16)\n",
    "fc0w.set_nodeattr(\"outFIFODepth\", 4)\n",
    "\n",
    "fc1w.set_nodeattr(\"SIMD\", 16)\n",
    "fc1w.set_nodeattr(\"PE\", 16)\n",
    "fc1w.set_nodeattr(\"outFIFODepth\", 4)\n",
    "\n",
    "fc2w.set_nodeattr(\"SIMD\", 16)\n",
    "fc2w.set_nodeattr(\"PE\", 16)\n",
    "fc2w.set_nodeattr(\"outFIFODepth\", 4)\n",
    "\n",
    "fc3w.set_nodeattr(\"SIMD\", 16)\n",
    "fc3w.set_nodeattr(\"PE\", 10)\n",
    "fc3w.set_nodeattr(\"outFIFODepth\", 50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we will run the `InsertTLastMarker` transformation to get a `TLastMarker` node at the output of this graph, which is necessary to run the DMA engines correctly. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1_set_folding_factors.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc654223780>"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from finn.transformation.fpgadataflow.insert_tlastmarker import InsertTLastMarker\n",
    "model = model.transform(InsertTLastMarker())\n",
    "model.save(build_dir+\"/tfc_w1_a1_set_folding_factors.onnx\")\n",
    "showInNetron(build_dir+\"/tfc_w1_a1_set_folding_factors.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This completes the network preparation and the network can be passed on to the next block *Vivado HLS and Vivado synthesis*, which is described below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Vivado HLS and Vivado IPI <a id='vivado'></a>\n",
    "* [Generating HLS Code](#hls_per_layer)\n",
    "* [Synthesizing HLS to IP Blocks](#hls_synth)\n",
    "* [IP Stitching](#ip_stitching)\n",
    "\n",
    "As we will be dealing with FPGA synthesis tools in these tasks, we'll define two helper variables that describe the Xilinx FPGA part name and the PYNQ board name that we are targeting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['Ultra96', 'Pynq-Z1'])\n"
     ]
    }
   ],
   "source": [
    "# print the names of the supported PYNQ boards\n",
    "from finn.util.basic import pynq_part_map\n",
    "print(pynq_part_map.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# change this if you have a different PYNQ board, see list above\n",
    "pynq_board = \"Ultra96\"\n",
    "fpga_part = pynq_part_map[pynq_board]\n",
    "target_clk_ns = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating HLS Code <a id='hls_per_layer'></a>\n",
    "This section deals with the generation of an IP block from the different layers. These can then be stitched to a block design that corresponds to the complete model. The single conversion into IP blocks allows a good transparency and we can check the functionality of each IP block and compare it with the behaviour of the corresponding ONNX node. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two transformations are required to generate HLS IP blocks for each layer: \n",
    "* `CodeGen_ipgen` which generates the HLS C++ code for the node and a tcl-script which starts the HLS synthesis and exports the design as IP. \n",
    "* `HLSSynth_IPGen` which passes the tcl-script to Vivado HLS and thus performs the actual IP generation. \n",
    "\n",
    "We start off by giving unique node names using the basic transformation `GiveUniqueNodeNames`, and then proceed with the HLS C++ code generation with `CodeGen_ipgen`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ModelWrapper(build_dir+\"/tfc_w1_a1_set_folding_factors.onnx\")\n",
    "model = model.transform(GiveUniqueNodeNames())\n",
    "\n",
    "from finn.transformation.fpgadataflow.codegen_ipgen import CodeGen_ipgen\n",
    "model = model.transform(CodeGen_ipgen(fpga_part, target_clk_ns))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Synthesizing HLS to IP Blocks <a id='hls_synth'></a>\n",
    "\n",
    "Now that we have generated the HLS code for each layer, we can call the `HLSSynth_IPGen` transformation to convert the generated HLS into Vivado IP blocks. **As this involves calling HLS synthesis, this transformation will run for some time (several minutes).**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.transformation.fpgadataflow.hlssynth_ipgen import HLSSynth_IPGen\n",
    "\n",
    "model = model.transform(HLSSynth_IPGen())\n",
    "model.save(build_dir+\"/tfc_w1_a1_ipgen.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each `StreamingFCLayer_Batch` node now has new attributes which can be examined more closely with netron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stopping http://0.0.0.0:8081\n",
      "Serving '/workspace/finn/tfc_w1_a1_ipgen.onnx' at http://0.0.0.0:8081\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"400\"\n",
       "            src=\"http://0.0.0.0:8081/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fc65422be48>"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "showInNetron(build_dir+\"/tfc_w1_a1_ipgen.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two additional attributes: \n",
    "* `code_gen_dir_ipgen` which contains the directory path where all the files generated by the ipgen transformations are stored\n",
    "* `ipgen_path` which contains the path to the project directory in which the generated IP block is stored\n",
    "\n",
    "We can further investigate which files are produced by taking a look in this directory. For example for the first StreamingFCLayer_Batch node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hls_syn_StreamingFCLayer_Batch_0.tcl  thresh.h\r\n",
      "ipgen.sh\t\t\t      top_StreamingFCLayer_Batch_0.cpp\r\n",
      "params.h\t\t\t      vivado_hls.log\r\n",
      "project_StreamingFCLayer_Batch_0\r\n"
     ]
    }
   ],
   "source": [
    "fc0w = getCustomOp(model.graph.node[0])\n",
    "code_gen_dir = fc0w.get_nodeattr(\"code_gen_dir_ipgen\")\n",
    "!ls {code_gen_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Directory *project_StreamingFCLayer_Batch_0* contains the project created by Vivado HLS into which the IP Block is exported, along with other files generated by Vivado HLS. If we compare it to the above visualization of the network with netron, this is exactly the name of the folder stored in the node attribute `ipgen_path`. The .cpp code that is passed to Vivado HLS can be found in the file *top_StreamingFCLayer_Batch_0.cpp*. The files *params.h* and *thresh.h* belong to that as well, they contain the values for the weights and thresholds. *vivado_hls.log* is the log file from Vivado HLS. Besides these files, the folder contains *ipgen.sh* and *hls_syn_StreamingFCLayer_Batch_0.tcl*. First we take a look at *ipgen.sh*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#!/bin/bash \r\n",
      "cd /tmp/finn_maltanar/code_gen_ipgen_StreamingFCLayer_Batch_hc367wg4\r\n",
      "vivado_hls /tmp/finn_maltanar/code_gen_ipgen_StreamingFCLayer_Batch_hc367wg4/hls_syn_StreamingFCLayer_Batch_0.tcl\r\n",
      "cd /workspace/finn\r\n"
     ]
    }
   ],
   "source": [
    "shell_script = code_gen_dir + \"/ipgen.sh\"\n",
    "!cat {shell_script}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The script consists only of two framing `cd` commands and a command to pass the tcl script to *vivado_hls*. The directory has to be changed to create the files in the correct folder and will then be changed back to the original directory. \n",
    "\n",
    "Below is the tcl script which is passed to *vivado_hls*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "set config_proj_name project_StreamingFCLayer_Batch_0\r\n",
      "puts \"HLS project: $config_proj_name\"\r\n",
      "set config_hwsrcdir \"/tmp/finn_maltanar/code_gen_ipgen_StreamingFCLayer_Batch_hc367wg4\"\r\n",
      "puts \"HW source dir: $config_hwsrcdir\"\r\n",
      "set config_proj_part \"xczu3eg-sbva484-1-e\"\r\n",
      "\r\n",
      "set config_bnnlibdir \"/workspace/finn-hlslib\"\r\n",
      "\r\n",
      "set config_toplevelfxn \"StreamingFCLayer_Batch_0\"\r\n",
      "set config_clkperiod 5\r\n",
      "\r\n",
      "open_project $config_proj_name\r\n",
      "add_files $config_hwsrcdir/top_StreamingFCLayer_Batch_0.cpp -cflags \"-std=c++0x -I$config_bnnlibdir\"\r\n",
      "\r\n",
      "set_top $config_toplevelfxn\r\n",
      "open_solution sol1\r\n",
      "set_part $config_proj_part\r\n",
      "\r\n",
      "config_interface -m_axi_addr64\r\n",
      "config_rtl -auto_prefix\r\n",
      "\r\n",
      "create_clock -period $config_clkperiod -name default\r\n",
      "csynth_design\r\n",
      "export_design -format ip_catalog\r\n",
      "exit 0\r\n"
     ]
    }
   ],
   "source": [
    "tcl_script = code_gen_dir + \"/hls_syn_StreamingFCLayer_Batch_0.tcl\"\n",
    "!cat {tcl_script}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first part of the script the project is configured. For example the FPGA part and the clock are set. Then the project is opened and the files are added. The toplevel function is set and after creating a clock, the design is first synthesized with `csynth` and then exported as an IP block.\n",
    "\n",
    "Now that all IP blocks are in place, they can be stitched together to create an IP design that matches the ONNX model. This is covered in the next section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IP Stitching <a id='ip_stitching'></a>\n",
    "\n",
    "We now have IP blocks for each of our layers, and will stitch them together into a larger IP that implements the whole network using the `CodeGen_ipstitch` transformation. Bear in mind that this transformation can only be applied on a graph that only contains HLS nodes that already have been through the `HLSSynth_IPGen` transformation, which is the last step we performed. Prior to calling IP stitching, we'll also use the `ReplaceVerilogRelPaths` transformation to convert any relative `$readmemh` paths in the generated IP blocks to absolute ones, which prevents errors later on. **This step invokes Vivado and may take a few minutes to run.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.transformation.fpgadataflow.codegen_ipstitch import CodeGen_ipstitch\n",
    "from finn.transformation.fpgadataflow.replace_verilog_relpaths import ReplaceVerilogRelPaths\n",
    "model = ModelWrapper(build_dir+\"/tfc_w1_a1_ipgen.onnx\")\n",
    "model = model.transform(ReplaceVerilogRelPaths())\n",
    "model = model.transform(CodeGen_ipstitch(fpga_part))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you examine the nodes themselves on the transformed model you won't see a difference, because the IP stitching adds model-level metadata to the graph. This can be accessed using the `.model.metadata_props`, the `get_metadata_prop` function in `ModelWrapper`, or by clicking on the global input/output tensors in Netron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[key: \"vivado_stitch_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke\"\n",
       ", key: \"vivado_stitch_vlnv\"\n",
       "value: \"xilinx_finn:finn:finn_design:1.0\"\n",
       ", key: \"wrapper_filename\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v\"\n",
       "]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.model.metadata_props"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke'"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_metadata_prop(\"vivado_stitch_proj\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you navigate to the folder above (remember the /tmp/finn_xxx folder is mounted on the host as well as inside Docker) you can open the Vivado project (.xpr) file there using Vivado, and view the following stitched IP block design:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](stitched_ip.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(build_dir+\"/tfc_w1_a1_ipstitch.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, one could take the generated stitched IP and integrate it into your own project using Vivado IP Integrator if desired. Here, we will continue the tutorial by assuming that we want to do a stand-alone deployment for this accelerator for a PYNQ board."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.  Synthesize, Deploy and Test on PYNQ <a id='hw_test'></a>\n",
    "\n",
    "* [Inserting the IP into a PYNQ Overlay Shell](#pynq_shell)\n",
    "* [Synthesis, place and route](#synth_pl_ro)\n",
    "* [Driver Generation](#driver_gen)\n",
    "* [Deployment and Remote Execution](#deploy)\n",
    "\n",
    "\n",
    "We are almost done preparing our hardware design. We'll now put it in a form suitable for use as a PYNQ overlay, synthesize and deploy it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inserting the IP into a PYNQ Overlay Shell <a id='pynq_shell'></a>\n",
    "\n",
    "We are almost done preparing our hardware design. To deploy our accelerator on a PYNQ platform, it needs to be put inside an appropriate *shell* that bridges it with the interfaces that the underlying system exposes. FINN makes it easy to create a PYNQ-compatible overlay by inserting the stitched IP into an appropriate PYNQ shell with the `MakePYNQProject` transformation, and view the created PYNQ shell project directory using the `metadata_props`. **This invokes Vivado and may take a few minutes to run.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[key: \"vivado_stitch_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke\"\n",
       ", key: \"vivado_stitch_vlnv\"\n",
       "value: \"xilinx_finn:finn:finn_design:1.0\"\n",
       ", key: \"wrapper_filename\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v\"\n",
       ", key: \"vivado_pynq_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_pynq_proj_hqlnpt5q\"\n",
       "]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from finn.transformation.fpgadataflow.make_pynq_proj import MakePYNQProject\n",
    "model = ModelWrapper(build_dir+\"/tfc_w1_a1_ipstitch.onnx\")\n",
    "model = model.transform(MakePYNQProject(pynq_board))\n",
    "model.model.metadata_props"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ip_config.tcl\t resizer.cache\tresizer.ip_user_files  resizer.xpr\r\n",
      "make_project.sh  resizer.hw\tresizer.srcs\t       synth_project.sh\r\n"
     ]
    }
   ],
   "source": [
    "! ls {model.get_metadata_prop(\"vivado_pynq_proj\")}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we open the created Vivado project (.xpr) under the `vivado_pynq_proj` directory above, we can see the system-level block design as below, with the FINN-generated part of the design highlighted. Various other components, such as the DMA engine and data width converters, have also been instantiated.\n",
    "![](pynq_shell_project.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(build_dir + \"/tfc_w1_a1_pynq_project.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Synthesis, place and route <a id='synth_pl_ro'></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready for the final hardware generation step, which is synthesis, place and route to generate an FPGA bitfile. This can be done by either running the `synth_project.sh` script in the generated Vivado PYNQ project directory inside Docker, or by executing the `SynthPYNQProject` transformation. **This step involves launching Vivado for synthesis and may take a few hours.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[key: \"vivado_stitch_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke\"\n",
       ", key: \"vivado_stitch_vlnv\"\n",
       "value: \"xilinx_finn:finn:finn_design:1.0\"\n",
       ", key: \"wrapper_filename\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v\"\n",
       ", key: \"vivado_pynq_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_pynq_proj_hqlnpt5q\"\n",
       ", key: \"vivado_pynq_bitfile\"\n",
       "value: \"/tmp/finn_maltanar/vivado_pynq_proj_hqlnpt5q/resizer.bit\"\n",
       "]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject\n",
    "model = ModelWrapper(build_dir + \"/tfc_w1_a1_pynq_project.onnx\")\n",
    "model = model.transform(SynthPYNQProject())\n",
    "model.model.metadata_props"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(build_dir + \"/tfc_w1_a1_post_synthesis.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Driver generation <a id='driver_gen'></a>\n",
    "\n",
    "Now that we have synthesized a bitfile for our network, we will generate some Python code for PYNQ that will act as the driver for this bitfile, package everything into a deployment folder and copy that to our PYNQ board."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver\n",
    "model = ModelWrapper(build_dir + \"/tfc_w1_a1_post_synthesis.onnx\")\n",
    "model = model.transform(MakePYNQDriver())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The generated driver is placed in a folder that is indicated by the `pynq_driver_dir` top-level metadata. We can examine the generated PYNQ Python driver code as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "from pynq import Overlay\r\n",
      "import numpy as np\r\n",
      "from pynq import allocate\r\n",
      "from finn.util.data_packing import (\r\n",
      "    finnpy_to_packed_bytearray,\r\n",
      "    packed_bytearray_to_finnpy\r\n",
      ")\r\n",
      "from finn.core.datatype import DataType\r\n",
      "\r\n",
      "bitfile_path = \"resizer.bit\"\r\n",
      "ol = Overlay(bitfile_path)\r\n",
      "dma=ol.axi_dma_0\r\n",
      "\r\n",
      "# declare input/output types and shapes for the accelerator\r\n",
      "# input FINN DataType\r\n",
      "idt = DataType.BINARY\r\n",
      "# normal, folded and packed input shapes\r\n",
      "ishape_normal = (1, 784)\r\n",
      "ishape_folded = (1, 49, 16)\r\n",
      "ishape_packed = (1, 49, 2)\r\n",
      "# output FINN DataType\r\n",
      "odt = DataType.UINT32\r\n",
      "# normal, folded and packed output shapes\r\n",
      "oshape_normal = (1, 10)\r\n",
      "oshape_folded = (1, 1, 10)\r\n",
      "oshape_packed = (1, 1, 40)\r\n",
      "\r\n",
      "# load desired input .npy file\r\n",
      "ibuf_normal = np.load(\"input.npy\")\r\n",
      "# ensure that shape is as expected\r\n",
      "assert ibuf_normal.shape == ishape_normal\r\n",
      "# convert to folded form\r\n",
      "ibuf_folded = ibuf_normal.reshape(ishape_folded)\r\n",
      "\r\n",
      "# pack the input buffer, reversing both SIMD dim and endianness\r\n",
      "ibuf_packed = finnpy_to_packed_bytearray(\r\n",
      "    ibuf_folded, idt, reverse_endian=True, reverse_inner=True\r\n",
      ")\r\n",
      "# allocate a PYNQ buffer for the packed input buffer\r\n",
      "ibuf_packed_device = allocate(shape=ishape_packed, dtype=np.uint8)\r\n",
      "# copy the packed data into the PYNQ buffer\r\n",
      "# TODO optimization: pack directly into the PYNQ buffer?\r\n",
      "np.copyto(ibuf_packed_device, ibuf_packed)\r\n",
      "\r\n",
      "# allocate a PYNQ buffer for the returned packed output buffer\r\n",
      "obuf_packed = allocate(shape=oshape_packed, dtype=np.uint8)\r\n",
      "\r\n",
      "# set up the DMA and wait until all transfers complete\r\n",
      "dma.sendchannel.transfer(ibuf_packed_device)\r\n",
      "dma.recvchannel.transfer(obuf_packed)\r\n",
      "dma.sendchannel.wait()\r\n",
      "dma.recvchannel.wait()\r\n",
      "\r\n",
      "# unpack the packed output buffer from accelerator\r\n",
      "obuf_folded = packed_bytearray_to_finnpy(\r\n",
      "    obuf_packed, odt, oshape_folded, reverse_endian=True, reverse_inner=True\r\n",
      ")\r\n",
      "# convert to normal reshape and save\r\n",
      "obuf_normal = obuf_folded.reshape(oshape_normal)\r\n",
      "np.save(\"output.npy\", obuf_normal)\r\n"
     ]
    }
   ],
   "source": [
    "driver_dir = model.get_metadata_prop(\"pynq_driver_dir\")\n",
    "! cat {driver_dir}/driver.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the generated driver contains the expected input/output shapes, expecting a file called `input.npy` to be provided prior to execution, which will be read in, packed into the format that the accelerator expects, running it and generating an `output.npy` file with the results. You can build your own applications around the accelerator by modifying the driver, or use the remote execution capabilities that FINN provides just to check if it is working, which will be our next step."
   ]
  },
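  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the driver's shape bookkeeping concrete, here is a minimal sketch (plain Python, independent of the generated driver) of where the packed input shape comes from: the 784 binary pixels are folded into 49 groups of 16 elements, and since each element is a single bit (`DataType.BINARY`), every group of 16 packs into 2 bytes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: reproduce the driver's input shape arithmetic for DataType.BINARY\n",
    "ishape_normal = (1, 784)  # one flattened 28x28 binary image\n",
    "simd = 16                 # elements per folded group (from ishape_folded)\n",
    "fold = 784 // simd        # -> 49 groups\n",
    "bits_per_elem = 1         # BINARY elements are 1 bit wide\n",
    "bytes_per_group = (simd * bits_per_elem + 7) // 8  # -> 2 bytes, rounded up\n",
    "ishape_folded = (1, fold, simd)             # -> (1, 49, 16)\n",
    "ishape_packed = (1, fold, bytes_per_group)  # -> (1, 49, 2)\n",
    "print(ishape_folded, ishape_packed)"
   ]
  },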
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deployment and Remote Execution <a id='deploy'></a>\n",
    "\n",
    "We'll now use the `DeployToPYNQ` transformation to create a deployment folder with the bitfile and driver file(s), and copy that to the PYNQ board. You can change the default IP address, username, password and target folder for the PYNQ below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.transformation.fpgadataflow.make_deployment import DeployToPYNQ\n",
    "ip = \"192.168.3.1\"\n",
    "username = \"xilinx\"\n",
    "password = \"xilinx\"\n",
    "target_dir = \"/home/xilinx/finn_tfc_end2end_example\"\n",
    "model = model.transform(DeployToPYNQ(ip, username, password, target_dir))\n",
    "model.save(build_dir + \"/tfc_w1_a1_pynq_deploy.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's verify that the remote access credentials is saved in the model metadata, and that the deployment folder has been successfully copied to the board:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[key: \"vivado_stitch_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke\"\n",
       ", key: \"vivado_stitch_vlnv\"\n",
       "value: \"xilinx_finn:finn:finn_design:1.0\"\n",
       ", key: \"wrapper_filename\"\n",
       "value: \"/tmp/finn_maltanar/vivado_stitch_proj_n3me5eke/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v\"\n",
       ", key: \"vivado_pynq_proj\"\n",
       "value: \"/tmp/finn_maltanar/vivado_pynq_proj_hqlnpt5q\"\n",
       ", key: \"vivado_pynq_bitfile\"\n",
       "value: \"/tmp/finn_maltanar/vivado_pynq_proj_hqlnpt5q/resizer.bit\"\n",
       ", key: \"pynq_driver_dir\"\n",
       "value: \"/tmp/finn_maltanar/pynq_driver_yu_l_jao\"\n",
       ", key: \"pynq_ip\"\n",
       "value: \"192.168.3.1\"\n",
       ", key: \"pynq_username\"\n",
       "value: \"xilinx\"\n",
       ", key: \"pynq_password\"\n",
       "value: \"xilinx\"\n",
       ", key: \"pynq_target_dir\"\n",
       "value: \"/home/xilinx/finn_tfc_end2end_example\"\n",
       ", key: \"pynq_deployment_dir\"\n",
       "value: \"/tmp/finn_maltanar/pynq_deployment_1oyo7x66\"\n",
       ", key: \"pynq_deploy_dir\"\n",
       "value: \"/tmp/finn_maltanar/pynq_deployment_1oyo7x66\"\n",
       ", key: \"exec_mode\"\n",
       "value: \"remote_pynq\"\n",
       "]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.model.metadata_props"
   ]
  },
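  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Individual entries can also be queried with `get_metadata_prop`, the same call we used earlier to look up the driver directory. For example, to retrieve just the deployment directory on the board side:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# query a single metadata entry instead of printing the whole list\n",
    "print(model.get_metadata_prop(\"pynq_deployment_dir\"))"
   ]
  },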
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 5820\r\n",
      "-rw-r--r-- 1 xilinx xilinx    1934 Feb 13 13:36 driver.py\r\n",
      "drwxr-xr-x 4 xilinx xilinx    4096 Feb 13 13:36 finn\r\n",
      "-rw-r--r-- 1 xilinx xilinx    3264 Feb 13 14:24 input.npy\r\n",
      "-rw-r--r-- 1 root   root       120 Feb 13 14:24 output.npy\r\n",
      "-rw-r--r-- 1 xilinx xilinx 5568787 Feb 13 13:36 resizer.bit\r\n",
      "-rw-r--r-- 1 xilinx xilinx  368173 Feb 13 13:36 resizer.hwh\r\n",
      "-rw-r--r-- 1 root   root        32 Feb 13 14:24 sds_trace_data.dat\r\n"
     ]
    }
   ],
   "source": [
    "! sshpass -p {password} ssh {username}@{ip} 'ls -l {target_dir}/*'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We only have two more steps to be able to remotely execute the deployed bitfile with some test data from the MNIST dataset. Let's load up some test data that comes bundled with FINN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fc62d83dbe0>"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAARX0lEQVR4nO3dfYyVZXrH8d/FoDAw8iYRCaisG/5QqmUbgk1KyOKmxlUMbKJm/aPauAmarMmqTVqz/UOSaqJVa/pH3YStL9CsmiWoq0a7a82mWo1GNFQQW1CULGR4E5H3t+HqH/NgZ3We6549z3nOc9z7+0kmM3Ouec65OTM/zsv13Pdt7i4Af/xGNT0AAJ1B2IFMEHYgE4QdyARhBzIxupM3Zma89Z+ZUaPKH09OnTpV23VXvf6enp6wPjAw0PJ1183dbbjLK4XdzK6U9M+SeiT9q7vfV+X6cmU27O/mS6k/6ip/eKNHx38CqcCk6r29vaW1Q4cOhcem9PX1hfUDBw6U1lIt50mTJoX1zz77LKx3o5afxptZj6R/kfR9SRdLusHMLm7XwAC0V5XX7PMlfeTuW9z9uKSnJS1pz7AAtFuVsM+Q9Lsh328rLvs9ZrbMzNaa2doKtwWgotrfoHP3FZJWSLxBBzSpyiP7dknnDfl+ZnEZgC5UJezvSJptZt8yszMl/VDS8+0ZFoB2a/lpvLufNLPbJP1ag623x9z9g7aNLCPjx48P6wcPHmz5useMGRPWjx07FtZTbcFx48aF9ai9lmoppqSOj9prqT76vn37WhpTN6v0mt3dX5L0UpvGAqBGnC4LZIKwA5kg7EAmCDuQCcIOZIKwA5mwTq4um+vpsqled6qXffTo0bA+duzYlo9Nia676vWfffbZYb3qNNLofp06dWp47O7du8N6amrwyZMnw3qdyuaz88gOZIKwA5kg7EAmCDuQCcIOZIKwA5mg9fYNkGrNVfkd1nnddUtNDa6yem1q6m5qanCTS03TegMyR9iBTBB2IBOEHcgEYQcyQdiBTBB2IBP02TvgrLPOCuvRbqOSNHHixLB+4sSJ0lpqN9LUFNbPP/88rC9YsCCs33rrraW1VC/6jjvuCOtbt24N601OM20SfXYgc4QdyARhBzJB2IFMEHYgE4QdyARhBzJBn/0b4JFHHgnrUS871Wuuuox1b29vWI+ktk2+5JJLwvqmTZvC+vHjx0trZ5xxRnhsdO6ClP53HzlyJKzXqazPXmnLZjP7VNIBSQOSTrr7vCrXB6A+lcJeWOTue9pwPQBqxGt2IBNVw+6SfmNm75rZsuF+wMyWmdlaM1tb8bYAVFD1afwCd99uZudIesXM/sfdXxv6A+6+QtIKiTfogCZVemR39+3F512SnpU0vx2DAtB+LYfdzMab2Vmnv5Z0haQN7RoYgPaq8jR+mqRniz7taElPuvu/t2VUf2RSWzYvWrQorF922WVhPeqVHzx4MDw21W/u6+sL66nzNKI566m11x999NGWr1uS7rzzztLaW2+9FR5b93bSTWg57O6+RdKftnEsAGpE6w3IBGEHMkHYgUwQdiAThB3IBFNcu0Bqqubs2bPD+v79+0trEyZMCI+NpoFK6SmwVbZ8TrX9UlJLcO/du7e0tnTp0vDYdevWhfVUSzLV8qwTS0kDmSPsQCYIO5AJwg5kgrADmSDsQCYIO5CJdiw42TFRT7fOfnBK6thU/ZZbbgnrq1atCuszZ85s+bZTffZ77rknrK9evTqsn3nmmaW1K664Ijz2wQcfDOuprbCj2168eHF47LZt28L6nj3fvDVWeWQHMkHYgUwQdiAThB3IBGEHMkHYgUwQdiATHZ/Pnup3Rzo51naqOvd54cKFYf2iiy4qrY0bNy48dvTo+FSLNWvWhPUtW7aE9SpSyz3PmTMnrKfu90jq75T57AC6FmEHMkHYgUwQdiAThB3IBGEHMkHYgUx0vM8+alT5/y9V54XXqcpc+lOnTlW67eg+S9VPnjwZHjt+/PiwfujQobCe2o46+p2l5tJfffXVYf3pp58O61X67Kk17VP3a5Na7rOb2WNmtsvMNgy5bIqZvWJmm4vPk9s5WADtN5Kn8U9IuvIrl90l6VV3ny3p1eJ7AF0sGXZ3f03SV/fRWSJpZfH1SknxXjoAGtfqGnTT3L2/+HqHpGllP2hmyyQta/F2ALRJ5QUn3d2jDRvdfYWkFRIbOwJNarX1ttPMpktS8XlX+4YEoA6thv15STcVX98k6VftGQ6AuiT77Gb2lKTvSpoqaaekuyU9J+mXks6XtFXS9e5evhn2/19XbU/jq64bX7UeSfVkU3uoR/uvV9Xb2xvWjxw5EtZT5wBUOcfgwgsvDOsff/xxy9edGldqTfqUw4cPVzq+irI+e/I1u7vfUFL6XqURAegoTpcFMkHYgUwQdiAThB3IBGEHMsGWzYVUC3JgYCCsR3p6esJ61WWHozZRqsWUmsKakrr+aNvkqCZJixYtamlMp0W/0xMnToTHpqa4Vvl7aAqP7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZKKr+ux1budcdTnnKuq+7QMHDpTWUv3iVK87dXyqTx8tF51axvq6664L60ePHg3rY8eOLa2l+uyp31mTWzK3ikd2IBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcy0fE+ezS3u5t75dGSyanllFPq3Fb50ksvDY+dM2dOWE8tJf3cc8+F9UjUB5ekhQsXhvUqW3inlqGOzl2Qqi/B3QQe2YFMEHYgE4QdyARhBzJB2IFMEHYgE4QdyETH++zRnPU6++ipufKped1RT3j06PhuXLp0aVhPHb9kyZKwPmbMmNLa3Llzw2MnTZoU1lO97Ndff73l42fPnh0em1qbPdXrXr9+fWnt8ssvD4+N7lOpO/voKclHdjN7zMx2mdmGIZctN7PtZrau+Liq3mECqGokT+OfkHTlMJc/7O5zi4+X2jssAO2WDLu7vyZpbwfGAqBGVd6gu83M3i+e5k8u+yEzW2Zma81sbYXbAlBRq2H/maRvS5orqV/SQ2U/6O4r3H2eu89r8bYAtEFLYXf3ne4+4O6nJP1c0vz2DgtAu7UUdjObPuTbH0jaUPazALqDpfqoZvaUpO9Kmippp6S7i+/nSnJJn0q6xd37kzdmFt5Yqt+cmvcdmTVrVli/5pprwvrixYtLa6l516l526m509H+61K8hnlfX194bErVed3R7/SLL74Ij504cWJYT9m8eXNpbdWqVeGxDz1U+spUUnf32d192JNKkifVuPsNw1z8aOURAegoTpcFMkHYgUwQdiAThB3IBGEHMpFsvbX1xsw8Wna5zimud999d1hfvnx5WN+zZ09pberUqa0M6UuprYf37o2nJkT1Cy64IDw21RZMbdmccuzYsdJaahpp6u8h1YqNpi2ntlx++eWXw/rNN98c1pvc0rms9cYjO5AJwg5kgrADmSDsQCYIO5AJwg5kgrADmeh4nz2qV9maODXVMtX3rLLt8q5du8L
61q1bw/oDDzwQ1levXh3W580rXwTo4YcfDo9Nbdk8eXLpimOSpG3btoX16Hf6xBNPhMd+8sknYf3aa68N69HU46rTa1988cWwnpoyXSf67EDmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZKKjffZRo0Z5ND/6+PHj4fHnnHNOaW337t3hsak+e2rudNQvTm0HvWnTprA+ZcqUsJ5atjha7vn8888Pj03NZ08t771v376wfuONN5bWXnjhhfDYlNQ6AtFy0YsWLQqPTa0xkLpfUst/14k+O5A5wg5kgrADmSDsQCYIO5AJwg5kgrADmeiq+exVpPqeK1euDOvXX399y9d/+PDh8Nhx48aF9dS2yKl5/gMDA6W11Lrvb775Zlh/8sknw/q6devC+htvvFFaS51fkOrhp37n0Xkb8+fPD499++23w/rjjz8e1lPrytep5T67mZ1nZr81s41m9oGZ/aS4fIqZvWJmm4vP8SoHABo1kqfxJyX9jbtfLOnPJf3YzC6WdJekV919tqRXi+8BdKlk2N29393fK74+IOlDSTMkLZF0+rnxSklL6xokgOriFz1fYWazJH1H0tuSprl7f1HaIWlayTHLJC1rfYgA2mHE78abWZ+kNZJud/f9Q2s++C7fsG++ufsKd5/n7uWrIgKo3YjCbmZnaDDov3D3Z4qLd5rZ9KI+XVK8xCqARiVbbzY4f3OlpL3ufvuQyx+Q9Jm732dmd0ma4u5/m7iu8MbOPffccCw7duwI65Fo+15JmjlzZli/9957S2szZswIj01tuZzaujjaLlqS7r///tLaxo0bw2NTU1xT2yKnpKYtR1JtwxMnToT1aOpx6u9+woQJYb3qlOk6lbXeRvKa/S8k/ZWk9WZ2uqn6U0n3Sfqlmf1I0lZJcaMaQKOSYXf3/5JU9l/k99o7HAB14XRZIBOEHcgEYQcyQdiBTBB2IBMdneLa09PjUV83NVU06n3u37+/tCZJfX19YT3VN416vlX6vVK655s6RyDqZad6+MeOHQvrVUW/79Ryzampwam/lyq/s5SqY6sTS0kDmSPsQCYIO5AJwg5kgrADmSDsQCYIO5CJrlpKOjWHOOqlp5YVrjove/r06aW1/v7+0tpI9Pb2hvXUls11XndqGetDhw6F9SpzylNGjYofq6rMKW/6/IQq6LMDmSPsQCYIO5AJwg5kgrADmSDsQCYIO5CJruqzA6iOPjuQOcIOZIKwA5kg7EAmCDuQCcIOZIKwA5lIht3MzjOz35rZRjP7wMx+Uly+3My2m9m64uOq+ocLoFXJk2rMbLqk6e7+npmdJeldSUs1uB/7QXd/cMQ3xkk1QO3KTqoZyf7s/ZL6i68PmNmHkma0d3gA6vYHvWY3s1mSviPp7eKi28zsfTN7zMwmlxyzzMzWmtnaSiMFUMmIz403sz5J/ynpXnd/xsymSdojySX9gwaf6t+cuA6exgM1K3saP6Kwm9kZkl6U9Gt3/6dh6rMkvejuf5K4HsIO1KzliTA2uDzoo5I+HBr04o27034gaUPVQQKoz0jejV8g6XVJ6yWdXpv3p5JukDRXg0/jP5V0S/FmXnRdPLIDNav0NL5dCDtQP+azA5kj7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmkgtOttkeSVuHfD+1uKwbdevYunVcEmNrVTvHdkFZoaPz2b9242Zr3X1eYwMIdOvYunVcEmNrVafGxtN4IBOEHchE02Ff0fDtR7p1bN06LomxtaojY2v0NTuAzmn6kR1AhxB2IBONhN3MrjSz/zWzj8zsribGUMbMPjWz9cU21I3uT1fsobfLzDYMuWyKmb1iZpuLz8PusdfQ2LpiG+9gm/FG77umtz/v+Gt2M+uRtEnSX0raJukdSTe4+8aODqSEmX0qaZ67N34ChpktlHRQ0qrTW2uZ2T9K2uvu9xX/UU5297/rkrEt1x+4jXdNYyvbZvyv1eB9187tz1vRxCP7fEkfufsWdz8u6WlJSxoYR9dz99ck7f3KxUskrSy+XqnBP5aOKxlbV3D3fnd/r/j6gKTT24w3et8F4+qIJsI+Q9Lvhny/Td2137tL+o2ZvWtmy5oezDCmDdlma4ekaU0OZhjJbbw76SvbjHfNfdfK9udV8Qbd1y1w9z+T9H1JPy6ernYlH3wN1k29059J+rYG9wDsl/RQk4MpthlfI+l2d98/tNbkfTfMuDpyvzUR9u2Szhvy/czisq7g7tuLz7skPavBlx3dZOfpHXSLz7saHs+X3H2nuw+4+ylJP1eD912xzfgaSb9w92eKixu/74YbV6futybC/o6k2Wb2LTM7U9IPJT3fwDi+xszGF2+cyMzGS7pC3bcV9fOSbiq+vknSrxocy+/plm28y7YZV8P3XePbn7t7xz8kXaXBd+Q/lvT3TYyhZFwXSvrv4uODpscm6SkNPq07ocH3Nn4k6WxJr0raLOk/JE3porH9mwa39n5fg8Ga3tDYFmjwKfr7ktYVH1c1fd8F4+rI/cbpskAmeIMOyARhBzJB2IFMEHYgE4QdyARhBzJB2IFM/B+tIjCppYWKvAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pkgutil import get_data\n",
    "import onnx.numpy_helper as nph\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "raw_i = get_data(\"finn\", \"data/onnx/mnist-conv/test_data_set_0/input_0.pb\")\n",
    "x = nph.to_array(onnx.load_tensor_from_string(raw_i))\n",
    "plt.imshow(x.reshape(28,28), cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall that we partitioned our original network into a parent graph that contained the non-synthesizable nodes and a child graph that contained the bulk of the network, which we turned into a bitfile. We'll load up the parent graph, modify the `StreamingDataflowPartition` node so that it points to the deployed ONNX graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "parent_model = ModelWrapper(build_dir+\"/tfc_w1_a1_dataflow_parent.onnx\")\n",
    "sdp_node = parent_model.graph.node[2]\n",
    "remote_exec_model = build_dir + \"/tfc_w1_a1_pynq_deploy.onnx\"\n",
    "getCustomOp(sdp_node).set_nodeattr(\"model\", remote_exec_model)\n",
    "parent_model.save(build_dir+\"/tfc_w1_a1_dataflow_parent_with_remote_bitfile_exec.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can call `execute_onnx` on the parent graph, which will internally call remote execution with the bitfile once the `StreamingDataflowPartition` node is reached, grab the results, then continue executing the last portion of the network. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from finn.core.onnx_exec import execute_onnx\n",
    "iname = parent_model.graph.input[0].name\n",
    "oname = parent_model.graph.output[0].name\n",
    "ishape = parent_model.get_tensor_shape(iname)\n",
    "input_dict = {iname: x.reshape(ishape)}\n",
    "ret = execute_onnx(parent_model, input_dict, True)"
   ]
  },
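  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we asked for the full execution context, `ret` maps every tensor name in the parent graph to its computed value, not just the final output. As a small sketch (the exact tensor names depend on the exported graph, so we look the output up by name rather than hard-coding it), we can list what was recorded:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: inspect the full execution context returned by execute_onnx\n",
    "print(\"tensors recorded:\", list(ret.keys()))\n",
    "print(\"output shape:\", ret[oname].shape)"
   ]
  },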
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll pass the output of the network through a softmax function to interpret it as probabilities, and plot the per-class probabilities as a bar chart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<BarContainer object of 10 artists>"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAMoUlEQVR4nO3cf6jd913H8edryercD1sxV9AkLgEzNQyl5dJVC1pshbSV5A9FGqjoKMs/y6yuKJlKHfWfzcn8gXUa5xzO2azWIcFGI7iKILbkdp11SYxcstrcrNK7rtYfQ7Pg2z/uiZzd3ptzkp57T/u+zwcEzvf7/XC+75ObPDn3e36kqpAkvfa9btoDSJImw6BLUhMGXZKaMOiS1IRBl6QmNk/rxFu2bKkdO3ZM6/SS9Jr05JNPfqmqZlY6NrWg79ixg7m5uWmdXpJek5L8y2rHvOQiSU0YdElqwqBLUhMjg57kY0meT/L5VY4nyW8mmU/ydJIbJj+mJGmUcZ6hfxzYc5njtwO7Bn8OAB955WNJkq7UyKBX1d8CX77Mkn3AH9aSx4HrknzLpAaUJI1nEtfQtwLnhrYXBvteJsmBJHNJ5hYXFydwaknSJev6omhVHa6q2aqanZlZ8X3xkqSrNImgnwe2D21vG+yTJK2jSXxS9ChwMMkR4B3AS1X13ATuV8vsOPTomp/jmQ/cuebnkLQ2RgY9yUPALcCWJAvALwGvB6iq3wGOAXcA88BXgHeu1bCSpNWNDHpV7R9xvIB3T2wiSdJV8ZOiktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1MRYQU+yJ8mZJPNJDq1w/NuSPJbkqSRPJ7lj8qNKki5nZNCTbAIeBG4HdgP7k+xetuwXgYer6nrgLuC3Jz2oJOnyxnmGfiMwX1Vnq+oCcATYt2xNAd8wuH0t8MXJjShJGsc4Qd8KnBvaXhjsG/Z+4O4kC8Ax4D0r3VGSA0nmkswtLi5exbiSpNVM6kXR/cDHq2obcAfwiSQvu++qOlxVs1U1OzMzM6FTS5JgvKCfB7YPbW8b7Bt2D/AwQFX9PfAGYMskBpQkjWecoJ8AdiXZmeQall70PLpszbPArQBJvouloHtNRZLW0cigV9VF4CBwHDjN0rtZTiZ5IMnewbL7gHcl+QfgIeAnq6rWamhJ0sttHmdRVR1j6cXO4X33D90+Bdw82dEkSVfCT4pKUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSE2MFPcmeJGeSzCc5tMqaH0tyKsnJJH882TElSaNsHrUgySbgQeCHgAXgRJKjVXVqaM0u4H3AzVX1YpJvXquBJUkrG+cZ+o3AfFWdraoLwBFg37I17wIerKoXAarq+cmOKUkaZZygbwXODW0vDPYNexvwtiR/l+TxJHtWuqMkB5LMJZlbXFy8uoklSSua1Iuim4FdwC3AfuD3kly3fFFVHa6q2aqanZmZmdCpJUkwXtDPA9uHtrcN9g1bAI5W1Ver6gvAP7MUeEnSOhkn6CeAXUl2JrkGuAs4umzNn7H07JwkW1i6BHN2gnNKkkYYGfSquggcBI4Dp4GHq+pkkgeS7B0sOw68kOQU8Bjws1X1wloNLUl6uZFvWwSoqmPAsWX77h+6XcB7B38kSVPgJ0UlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpibGCnmRPkjNJ5pMcusy6H0lSSWYnN6IkaRwjg55kE/AgcDuwG9ifZPcK694C3As8MekhJUmjjfMM/UZgvqrOVtUF4Aiwb4V1vwx8EPjvCc4nSRrTOEHfCpwb2l4Y7Pt/SW4AtlfVo5e7oyQHkswlmVtcXLziYSVJq3vFL4omeR3wYeC+UWur6nBVzVbV7MzMzCs9tSRpyDhBPw9sH9reNth3yVuAtwN/k+QZ4CbgqC+MStL6GifoJ4BdSXYmuQa4Czh66WBVvVRVW6pqR1XtAB4H9lbV3JpMLEla0cigV9VF4CBwHDgNPFxVJ5M8kGTvWg8oSRrP5nEWVdUx4NiyffevsvaWVz6WJOlK+UlRSWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJamKsoCfZk+RMkvkkh1Y4/t4kp5I8neSvk7x18qNKki5nZNCTbAIeBG4HdgP7k+xetuwpYLaqvht4BPiVSQ8qSbq8cZ6h3wjMV9XZqroAHAH2DS+oqseq6iuDzceBbZMdU5I0yjhB3wqcG9peGOxbzT3AX6x0IMmBJHNJ5hYXF8efUpI00kRfFE1yNzALfGil41V1uKpmq2p2ZmZmkqeWpA1v8xhrzgPbh7a3DfZ9jSS3Ab8A/EBV/c9kxpMkjWucZ+gngF1Jdia5BrgLODq8IMn1wO8Ce6vq+cmPKUkaZWTQq+oicBA4DpwGHq6qk0keSLJ3sOxDwJuBP0nyuSRHV7k7SdIaGeeSC1V1DDi2bN/9Q7dvm/BckqQr5CdFJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqYmxgp5kT5IzSeaTHFrh+Ncl+dTg+BNJdkx6UEnS5Y0MepJNwIPA7cBuYH+S3cuW3QO8WFXfDvwa8MFJDypJurzNY6y5EZivqrMASY4A+4BTQ2v2Ae8f3H4E+K0kqaqa4Kyaoh2HHl3zczzzgTvX/ByvNWv99+7feS/jBH0rcG5oewF4x2prqupikpeAbwK+NLwoyQHgwGDzP5OcuZqhr9KW5fNsEFf0uDPF360mfG5/3mOY5s97wjbSz/utqx0
YJ+gTU1WHgcPrec5LksxV1ew0zj1NPu6Nxce9sY3zouh5YPvQ9rbBvhXXJNkMXAu8MIkBJUnjGSfoJ4BdSXYmuQa4Czi6bM1R4CcGt38U+IzXzyVpfY285DK4Jn4QOA5sAj5WVSeTPADMVdVR4PeBTySZB77MUvRfbaZyqedVwMe9sfi4N7D4RFqSevCTopLUhEGXpCbaB33U1xZ0lGR7kseSnEpyMsm9055pPSXZlOSpJH8+7VnWU5LrkjyS5J+SnE7yvdOeaT0k+ZnBv/PPJ3koyRumPdO0tA76mF9b0NFF4L6q2g3cBLx7gzzuS+4FTk97iCn4DeAvq+o7ge9hA/wdJNkK/BQwW1VvZ+mNG6/GN2Wsi9ZBZ+hrC6rqAnDpawtaq6rnquqzg9v/wdJ/7K3TnWp9JNkG3Al8dNqzrKck1wLfz9I7zqiqC1X1b9Odat1sBr5+8BmYNwJfnPI8U9M96Ct9bcGGCNslg2++vB54YrqTrJtfB34O+N9pD7LOdgKLwB8MLjd9NMmbpj3UWquq88CvAs8CzwEvVdVfTXeq6eke9A0tyZuBPwV+uqr+fdrzrLUkPww8X1VPTnuWKdgM3AB8pKquB/4LaP+aUZJvZOm37p3AtwJvSnL3dKeanu5BH+drC1pK8nqWYv7Jqvr0tOdZJzcDe5M8w9LltR9M8kfTHWndLAALVXXpN7FHWAp8d7cBX6iqxar6KvBp4PumPNPUdA/6OF9b0E6SsHQt9XRVfXja86yXqnpfVW2rqh0s/aw/U1Ub4tlaVf0rcC7Jdwx23crXfsV1V88CNyV54+Df/a1sgBeDV7Ou37a43lb72oIpj7UebgZ+HPjHJJ8b7Pv5qjo2xZm09t4DfHLw5OUs8M4pz7PmquqJJI8An2Xp3V1PsYG/BsCP/ktSE90vuUjShmHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUxP8BwjHuoBhu1y0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def softmax(x):\n",
    "    \"\"\"Compute softmax values for each sets of scores in x.\"\"\"\n",
    "    e_x = np.exp(x - np.max(x))\n",
    "    return e_x / e_x.sum()\n",
    "\n",
    "logits = ret[oname].flatten()\n",
    "prob = softmax(logits)\n",
    "\n",
    "plt.bar(np.arange(10), prob)"
   ]
  },
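  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To read the prediction off as a single class label, we can take the argmax of the probabilities computed in the cell above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the predicted class is the index with the highest probability\n",
    "print(\"predicted digit:\", np.argmax(prob))"
   ]
  },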
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the network correctly predicts this as a digit 2 with high probability. This concludes our tutorial on how to take a simple fully-connected BNN all the way down to hardware with FINN, and execute it remotely on a PYNQ board."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}