%% Cell type:markdown id: tags:
# FINN - End-to-End Flow
-----------------------------------------------------------------
In this notebook, we will show how to take a simple, binarized, fully-connected network trained on the MNIST data set and take it all the way down to a customized bitfile running on a PYNQ board.
This notebook is quite lengthy, and some of the cells (involving Vivado synthesis) may take up to an hour to finish running. To let you save and resume your progress, we will save the intermediate ONNX models that are generated in the various steps to disk, so that you can jump back directly to where you left off.
%% Cell type:markdown id: tags:
## Overview
The FINN compiler comes with many *transformations* that modify the ONNX representation of the network according to certain patterns. This notebook will demonstrate a *possible* sequence of such transformations to take a particular trained network all the way down to hardware, as shown in the figure below.
%% Cell type:markdown id: tags:
![](finn-design-flow-example.svg)
%% Cell type:markdown id: tags:
The white fields show the state of the network representation in the respective step. The colored fields represent the transformations that are applied to the network to achieve a certain result. The diagram is divided into 5 sections, each represented by a different color and comprising several flow steps. The flow starts in the top left corner with the Brevitas export (green section), followed by the preparation of the network (blue section) for Vivado HLS synthesis and Vivado IPI stitching (orange section), and finally building a PYNQ overlay bitfile and testing it on a PYNQ board (yellow section).
There is an additional section for functional verification (red section) on the right side of the diagram, which we will not cover in this notebook. For details please take a look at the verification notebook, which you can find [here](tfc_end2end_verification.ipynb).
This Jupyter notebook is organized based on the sections described above. We will use the following helper functions: `showSrc` to show the source code of FINN library calls, and `showInNetron` to show the ONNX model at the current transformation step. The Netron displays are interactive, but they only work when running the notebook actively, not on GitHub (i.e. if you are viewing this on GitHub you will only see blank squares).
%% Cell type:code id: tags:
``` python
from finn.util.visualization import showSrc, showInNetron
from finn.util.basic import make_build_dir
build_dir = "/workspace/finn"
```
%% Cell type:markdown id: tags:
## Outline
-------------
1. [Brevitas export](#brev_exp)
2. [Network preparation](#nw_prep)
3. [Vivado HLS and IPI](#vivado)
4. [PYNQ hardware generation and deployment](#hw_test)
%% Cell type:markdown id: tags:
## 1. Brevitas export <a id='brev_exp'></a>
FINN expects an ONNX model as input. This can be a model trained with [Brevitas](https://github.com/Xilinx/brevitas). Brevitas is a PyTorch library for quantization-aware training, and the FINN Docker image comes with several [example Brevitas networks](https://github.com/maltanar/brevitas_cnv_lfc). To show the FINN end-to-end flow, we'll use the TFC-w1a1 model as the example network.
First a few things have to be imported. Then the model can be loaded with the pretrained weights.
%% Cell type:code id: tags:
``` python
import onnx
from finn.util.test import get_test_model_trained
import brevitas.onnx as bo
tfc = get_test_model_trained("TFC", 1, 1)
bo.export_finn_onnx(tfc, (1, 1, 28, 28), build_dir+"/tfc_w1_a1.onnx")
```
%% Output
/workspace/brevitas_cnv_lfc/training_scripts/models/TFC.py:85: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
x = 2.0 * x - torch.tensor([1.0]).to(self.device)
%% Cell type:markdown id: tags:
The model has now been exported, loaded with the pretrained weights and saved under the name "tfc_w1_a1.onnx".
To visualize the exported model, Netron can be used. Netron is a visualizer for neural networks and allows interactive investigation of network properties. For example, you can click on the individual nodes and view the properties.
%% Cell type:code id: tags:
``` python
showInNetron(build_dir+"/tfc_w1_a1.onnx")
```
%% Output
Serving '/workspace/finn/tfc_w1_a1.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1ad0b6e80>
%% Cell type:markdown id: tags:
Now that we have the model in .onnx format, we can work with it using FINN. For this, the FINN `ModelWrapper` is used. It is a wrapper around the ONNX model which provides several helper functions to make it easier to work with the model.
%% Cell type:code id: tags:
``` python
from finn.core.modelwrapper import ModelWrapper
model = ModelWrapper(build_dir+"/tfc_w1_a1.onnx")
```
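%% Cell type:markdown id: tags:
As a quick illustrative sketch of those helper functions, we can inspect a few basic properties of the wrapped model. The tensor names below are simply read from the graph, so the exact output depends on the exported model.
%% Cell type:code id: tags:
``` python
# a few ModelWrapper conveniences (illustrative sketch)
print("Number of nodes:", len(model.graph.node))  # raw ONNX graph access
inp_name = model.graph.input[0].name              # name of the global input tensor
print("Input tensor:", inp_name)
print("Input shape:", model.get_tensor_shape(inp_name))            # helper: tensor shape lookup
print("Input FINN datatype:", model.get_tensor_datatype(inp_name)) # helper: FINN datatype annotation
```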
%% Cell type:markdown id: tags:
Now the model is prepared and can be simulated using Python. How this works is described in the Jupyter notebook about verification, which can be found [here](tfc_end2end_verification.ipynb#simpy).
The model can now also be processed in different ways. The core principle of FINN is the use of analysis and transformation passes, which can be applied to the model. An analysis pass extracts specific information about the model and returns it to the user in the form of a dictionary. A transformation pass changes the model and returns the changed model back to the FINN flow.
Since the goal in this notebook is to process the model to the point where a bitstream can be generated from it, the focus is on the transformations that are necessary for this. These are discussed in more detail in the next sections.
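%% Cell type:markdown id: tags:
As a small illustrative sketch of the two pass types before moving on: an analysis pass is simply a function that takes the model and returns a dictionary, applied with `model.analysis(...)`, while a transformation is applied with `model.transform(...)` and returns the modified model. The `count_op_types` helper below is written here purely for illustration and is not part of the FINN library.
%% Cell type:code id: tags:
``` python
# illustrative analysis pass: count how many nodes of each op type the graph contains
def count_op_types(model):
    counts = {}
    for node in model.graph.node:
        counts[node.op_type] = counts.get(node.op_type, 0) + 1
    return counts

# analysis passes are applied via model.analysis, transformations via model.transform
print(model.analysis(count_op_types))
```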
%% Cell type:markdown id: tags:
## 2. Network preparation <a id='nw_prep'></a>
* [FINN-style Dataflow Architectures](#dataflow_arch)
* [Tidy-up transformations](#basic_trafo)
* [Streamlining](#streamline)
* [Conversion to HLS layers](#hls_layers)
* [Creating a Dataflow Partition](#dataflow_partition)
* [Folding and Datawidth Converter, FIFO and TLastMarker Insertion](#folding)
In this section, we will put the network through a series of transformations that puts it in a form that can be stitched together to form a FINN-style dataflow architecture, yielding a high-performance, high-efficiency FPGA accelerator.
%% Cell type:markdown id: tags:
### FINN-style Dataflow Architectures <a id='dataflow_arch'></a>
We start with a quick recap of FINN-style dataflow architectures. The key idea in such architectures is to parallelize across layers as well as within layers by dedicating a proportionate amount of compute resources to each layer, as illustrated in the figure below taken from the [FINN-R paper](https://arxiv.org/pdf/1809.04570.pdf):
![](finn-hw-arch.png)
In practice, the compute arrays are instantiated by function calls to optimized Vivado HLS building blocks from the [finn-hlslib](https://github.com/Xilinx/finn-hlslib) library. As these function calls can only handle certain patterns/cases, we need to transform the network into an appropriate form so that we can replace network layers with these function calls, which is the goal of the network preparation process.
%% Cell type:markdown id: tags:
### Tidy-up transformations <a id='basic_trafo'></a>
This section deals with some basic transformations, which are applied to the model as a kind of "tidy-up" to make it easier to process. They do not appear in the diagram above, but they are applied at many points in the FINN flow to postprocess the model after a transformation and/or to prepare it for the next one.
%% Cell type:markdown id: tags:
These transformations are:
* GiveUniqueNodeNames
* GiveReadableTensorNames
* InferShapes
* InferDataTypes
* FoldConstants
%% Cell type:markdown id: tags:
In the first two transformations (`GiveUniqueNodeNames`, `GiveReadableTensorNames`) the nodes in the graph are first given unique names (by enumeration), then the tensors are given human-readable names (based on the node names). The following two transformations (`InferShapes`, `InferDataTypes`) derive the shapes and data types of the tensors from the model properties and set them in the `ValueInfo` of the model. These four transformations can almost always be applied without negative effects: they do not change the structure of the graph and ensure that all the information needed later is available.
The last listed transformation is `FoldConstants`, which performs constant folding. It identifies nodes whose inputs are all constant, computes their outputs, sets the result as a constant input of the following node and removes the old node. Although this transformation changes the structure of the model, it is usually desired and can be applied to any model.
%% Cell type:markdown id: tags:
These transformations can be imported and applied as follows.
%% Cell type:code id: tags:
``` python
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.fold_constants import FoldConstants
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
model.save(build_dir+"/tfc_w1_a1_tidy.onnx")
```
%% Cell type:markdown id: tags:
The result of these transformations can be viewed with netron after the model has been saved again. By clicking on the individual nodes, you can now see, for example, that each node has been given a name. The whole upper part of the graph has also been constant-folded, so that the first node is now "Reshape".
%% Cell type:code id: tags:
``` python
showInNetron(build_dir+"/tfc_w1_a1_tidy.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1_a1_tidy.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1ad0639e8>
%% Cell type:markdown id: tags:
### Streamlining <a id='streamline'></a>
Streamlining is a transformation containing several sub-transformations. The goal of streamlining is to eliminate floating point operations by moving them around, collapsing them into one operation and finally transforming them into multi-thresholding nodes. For more information on the theoretical background of this, see [this paper](https://arxiv.org/pdf/1709.04060).
Let's have a look at which sub-transformations `Streamline` consists of:
%% Cell type:code id: tags:
``` python
from finn.transformation.streamline import Streamline
showSrc(Streamline)
```
%% Output
class Streamline(Transformation):
"""Apply the streamlining transform, see arXiv:1709.04060."""
def apply(self, model):
streamline_transformations = [
ConvertSubToAdd(),
ConvertDivToMul(),
BatchNormToAffine(),
ConvertSignToThres(),
MoveAddPastMul(),
MoveScalarAddPastMatMul(),
MoveScalarAddPastConv(),
MoveScalarMulPastMatMul(),
MoveScalarMulPastConv(),
MoveAddPastMul(),
CollapseRepeatedAdd(),
CollapseRepeatedMul(),
AbsorbAddIntoMultiThreshold(),
FactorOutMulSignMagnitude(),
AbsorbMulIntoMultiThreshold(),
Absorb1BitMulIntoMatMul(),
Absorb1BitMulIntoConv(),
RoundAndClipThresholds(),
]
for trn in streamline_transformations:
model = model.transform(trn)
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
return (model, False)
%% Cell type:markdown id: tags:
As can be seen, the streamlining transformation consists of several move and collapse transformations; in the last step the remaining operations are transformed into multi-thresholds. The individual transformations can be viewed in detail [here](https://github.com/Xilinx/finn/tree/master/src/finn/transformation/streamline). After each transformation, three of the tidy-up transformations (`GiveUniqueNodeNames`, `GiveReadableTensorNames` and `InferDataTypes`) are applied to the model.
After streamlining the network looks as follows:
%% Cell type:code id: tags:
``` python
model = ModelWrapper(build_dir+"/tfc_w1_a1_tidy.onnx")
model = model.transform(Streamline())
model.save(build_dir+"/tfc_w1_a1_streamlined.onnx")
showInNetron(build_dir+"/tfc_w1_a1_streamlined.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1_a1_streamlined.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1346e4ef0>
%% Cell type:markdown id: tags:
You can see that the network has been simplified considerably compared to the previous step -- a lot of nodes have disappeared between the `MatMul` layers, and the `Sign` nodes have been replaced with `MultiThreshold` nodes.
**The current implementation of streamlining is highly network-specific and may not work for your network if its topology is very different than the example network here. We hope to rectify this in future releases.**
Our example network is a quantized network with 1-bit bipolar (-1, +1 values) precision, and we want FINN to implement its matrix multiplications as XNOR-popcount operations [as described in the original FINN paper](https://arxiv.org/pdf/1612.07119). For this reason, after streamlining, the resulting bipolar matrix multiplications are converted into XNOR-popcount operations. This transformation produces additional operations that are again collapsed and converted into thresholds. This procedure is shown below.
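%% Cell type:markdown id: tags:
As a quick numeric aside before running it (purely illustrative, not a FINN API): for bipolar vectors a, b with entries in {-1, +1}, the dot product equals 2 * popcount(XNOR(a_bin, b_bin)) - N, where x_bin = (x + 1) / 2. This is the identity that the conversion below relies on.
%% Cell type:code id: tags:
``` python
import numpy as np
# sanity check of the bipolar-to-XNOR-popcount identity on a random example
n = 8
a = np.random.choice([-1, 1], size=n)
b = np.random.choice([-1, 1], size=n)
a_bin, b_bin = (a + 1) // 2, (b + 1) // 2
matches = int((a_bin == b_bin).sum())  # popcount of XNOR(a_bin, b_bin)
assert a @ b == 2 * matches - n
print("bipolar dot product:", a @ b, "  2*popcount(xnor) - N:", 2 * matches - n)
```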
%% Cell type:code id: tags:
``` python
from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
import finn.transformation.streamline.absorb as absorb
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
model = model.transform(ConvertBipolarMatMulToXnorPopcount())
model = model.transform(absorb.AbsorbAddIntoMultiThreshold())
model = model.transform(absorb.AbsorbMulIntoMultiThreshold())
model = model.transform(RoundAndClipThresholds())
model.save(build_dir+"/tfc_w1a1_ready_for_hls_conversion.onnx")
showInNetron(build_dir+"/tfc_w1a1_ready_for_hls_conversion.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1a1_ready_for_hls_conversion.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1346f7780>
%% Cell type:markdown id: tags:
Observe the pairs of `XnorPopcountMatMul` and `MultiThreshold` layers following each other -- this is the particular pattern that the next step will be looking for in order to convert them to HLS layers.
%% Cell type:markdown id: tags:
### Conversion to HLS layers <a id='hls_layers'></a>
This transformation converts suitable nodes to HLS layers that correspond to functions in the [finn-hls library](https://finn-hlslib.readthedocs.io/en/latest/). In our case it converts pairs of binary XnorPopcountMatMul layers to StreamingFCLayer_Batch layers. Any immediately following MultiThreshold layers will also be absorbed into the MVTU.
Below is the code for the transformation, after which the network is visualized using netron to show the new structure with `StreamingFCLayer_Batch` nodes, each of which corresponds to a function call from the [finn-hlslib](https://finn-hlslib.readthedocs.io/en/latest/library/fclayer.html#_CPPv4I_j_j_j_j000_i_i000E22StreamingFCLayer_BatchvRN3hls6streamI7ap_uintI9InStreamWEEERN3hls6streamI7ap_uintI10OutStreamWEEERK2TWRK2TAKjRK1R) library.
%% Cell type:markdown id: tags:
**Note:** The transformation `to_hls.InferBinaryStreamingFCLayer` gets the string "decoupled" as an argument; this indicates the `mem_mode` for the weights. In FINN there are different options for how the weights are stored and accessed. For details please have a look at the [FINN readthedocs website](https://finn.readthedocs.io/) under Internals.
%% Cell type:code id: tags:
``` python
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
model = ModelWrapper(build_dir+"/tfc_w1a1_ready_for_hls_conversion.onnx")
model = model.transform(to_hls.InferBinaryStreamingFCLayer("decoupled"))
model.save(build_dir+"/tfc_w1_a1_hls_layers.onnx")
showInNetron(build_dir+"/tfc_w1_a1_hls_layers.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1_a1_hls_layers.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1346f1080>
%% Cell type:markdown id: tags:
Each StreamingFCLayer_Batch node has two attributes that specify the degree of folding, PE and SIMD. By default these attributes are set to 1 in all nodes, which corresponds to maximum folding (time multiplexing) and thus minimum performance. We will shortly cover how these can be adjusted, but first we want to separate the HLS layers from the non-HLS layers in this network.
%% Cell type:markdown id: tags:
### Creating a Dataflow Partition <a id='dataflow_partition'></a>
In the graph above, you can see that there is a mixture of FINN HLS layers (StreamingFCLayer_Batch) with regular ONNX layers (Reshape, Mul, Add). To create a bitstream, FINN needs a model with only HLS layers. In order to achieve this, we will use the `CreateDataflowPartition` transformation to create a "dataflow partition" in this graph, separating out the HLS layers into another model, and replacing them with a placeholder layer called StreamingDataflowPartition:
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.create_dataflow_partition import CreateDataflowPartition
model = ModelWrapper(build_dir+"/tfc_w1_a1_hls_layers.onnx")
parent_model = model.transform(CreateDataflowPartition())
parent_model.save(build_dir+"/tfc_w1_a1_dataflow_parent.onnx")
showInNetron(build_dir+"/tfc_w1_a1_dataflow_parent.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1_a1_dataflow_parent.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1ad0b6e48>
%% Cell type:markdown id: tags:
We can see that the StreamingFCLayer instances have all been replaced with a single `StreamingDataflowPartition`, which has an attribute `model` that points to the extracted, HLS dataflow-only graph:
%% Cell type:code id: tags:
``` python
from finn.custom_op.registry import getCustomOp
sdp_node = getCustomOp(parent_model.graph.node[2])
dataflow_model_filename = sdp_node.get_nodeattr("model")
showInNetron(dataflow_model_filename)
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/tmp/finn_dev_jakobap/dataflow_partition_pbrjefjg/df_model.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1346f3550>
%% Cell type:markdown id: tags:
We can see all the extracted `StreamingFCLayer` instances have been moved to the child (dataflow) model. We will load the child model with `ModelWrapper` and continue working on it.
%% Cell type:code id: tags:
``` python
model = ModelWrapper(dataflow_model_filename)
```
%% Cell type:markdown id: tags:
### Folding and Datawidth Converter, FIFO and TLastMarker Insertion <a id='folding'></a>
*Folding* in FINN describes how much a layer is time-multiplexed in terms of execution resources. There are several *folding factors* for each layer, controlled by the PE (parallelization over outputs) and SIMD (parallelization over inputs) parameters as described by the original [FINN paper](https://arxiv.org/pdf/1612.07119). The higher the PE and SIMD values are set, the faster the generated accelerator will run, and the more FPGA resources it will consume.
Since the folding parameters are node attributes, they can easily be accessed and changed using the helper functions of the `ModelWrapper`. But first we take a closer look at one of the nodes that implements a StreamingFCLayer_Batch operation. This is where the Netron visualization helps us: in the above diagram we can see that the first four nodes are StreamingFCLayer_Batch, so as an example we extract the first node.
%% Cell type:markdown id: tags:
We can use the higher-level [HLSCustomOp](https://github.com/Xilinx/finn/blob/master/src/finn/custom_op/fpgadataflow/__init__.py) wrappers for this node. These wrappers provide easy access to specific properties of these nodes, such as the folding factors (PE and SIMD). Let's have a look at which node attributes are defined by the CustomOp wrapper, and adjust the SIMD and PE attributes.
%% Cell type:code id: tags:
``` python
fc0 = model.graph.node[0]
fc0w = getCustomOp(fc0)
print("CustomOp wrapper is of class " + fc0w.__class__.__name__)
fc0w.get_nodeattr_types()
```
%% Output
CustomOp wrapper is of class StreamingFCLayer_Batch
{'PE': ('i', True, 0),
'SIMD': ('i', True, 0),
'MW': ('i', True, 0),
'MH': ('i', True, 0),
'resType': ('s', True, ''),
'ActVal': ('i', False, 0),
'inputDataType': ('s', True, ''),
'weightDataType': ('s', True, ''),
'outputDataType': ('s', True, ''),
'binaryXnorMode': ('i', False, 0),
'noActivation': ('i', False, 0),
'numInputVectors': ('ints', False, [1]),
'mem_mode': ('s', False, 'const'),
'ram_style': ('s', False, 'auto'),
'backend': ('s', True, 'fpgadataflow'),
'code_gen_dir_cppsim': ('s', False, ''),
'code_gen_dir_ipgen': ('s', False, ''),
'executable_path': ('s', False, ''),
'ipgen_path': ('s', False, ''),
'ip_path': ('s', False, ''),
'ip_vlnv': ('s', False, ''),
'exec_mode': ('s', False, ''),
'sim_cycles': ('i', False, 0),
'rtlsim_trace': ('s', False, ''),
'res_estimate': ('s', False, ''),
'res_hls': ('s', False, ''),
'res_synth': ('s', False, ''),
'rtlsim_so': ('s', False, ''),
'inFIFODepth': ('i', False, 2),
'outFIFODepth': ('i', False, 2)}
%% Cell type:markdown id: tags:
We can see that the PE and SIMD are listed as node attributes, as well as the depths of the FIFOs that will be inserted between consecutive layers, and all can be adjusted using `set_nodeattr` subject to certain constraints.
**In this notebook we are setting the folding factors and FIFO depths manually, but in a future version we will support determining the folding factors given an FPGA resource budget according to the analytical model from the [FINN-R paper](https://arxiv.org/pdf/1809.04570).**
%% Cell type:code id: tags:
``` python
fc_layers = model.get_nodes_by_op_type("StreamingFCLayer_Batch")
# (PE, SIMD, in_fifo_depth, out_fifo_depth, ramstyle) for each layer
config = [
(16, 49, 16, 64, "block"),
(8, 8, 64, 64, "auto"),
(8, 8, 64, 64, "auto"),
(10, 8, 64, 10, "distributed"),
]
for fcl, (pe, simd, ififo, ofifo, ramstyle) in zip(fc_layers, config):
fcl_inst = getCustomOp(fcl)
fcl_inst.set_nodeattr("PE", pe)
fcl_inst.set_nodeattr("SIMD", simd)
fcl_inst.set_nodeattr("inFIFODepth", ififo)
fcl_inst.set_nodeattr("outFIFODepth", ofifo)
fcl_inst.set_nodeattr("ram_style", ramstyle)
```
%% Cell type:markdown id: tags:
We are setting PE and SIMD so that each layer has a total folding of 64 (the small final layer ends up with a lower folding), as the short sanity check below illustrates.
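%% Cell type:markdown id: tags:
As a short sanity check (illustrative sketch), the total folding of a `StreamingFCLayer_Batch` node can be computed from its attributes as (MW/SIMD) * (MH/PE), i.e. roughly the number of cycles it needs per input sample:
%% Cell type:code id: tags:
``` python
# illustrative: compute the total folding (cycles per sample) for each layer from its node attributes
for fcl in fc_layers:
    inst = getCustomOp(fcl)
    mw, mh = inst.get_nodeattr("MW"), inst.get_nodeattr("MH")
    pe, simd = inst.get_nodeattr("PE"), inst.get_nodeattr("SIMD")
    print(fcl.name, "total folding:", (mw // simd) * (mh // pe))
```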
%% Cell type:markdown id: tags:
Besides PE and SIMD, three other node attributes are set. `ram_style` specifies how the weights are to be stored (BRAM, LUTRAM, and so on); it can be selected explicitly, or with the option `auto` you can let Vivado decide.
`inFIFODepth` and `outFIFODepth` specify the FIFO depths that the node requests from the surrounding FIFOs. These attributes are used by the `InsertFIFO` transformation to insert the appropriate FIFOs between the nodes.
Before FIFOs can be added, however, it must be determined whether datawidth converters (DWCs) are required and they must be inserted correctly: depending on the folding, the folded output shape of one node may not match the folded input shape of the next node, as the small check below illustrates.
In the following, first DWCs and then FIFOs are inserted using the corresponding transformations in FINN.
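%% Cell type:markdown id: tags:
To see why a DWC can be needed here (illustrative sketch): a `StreamingFCLayer_Batch` produces PE elements per cycle on its output stream and consumes SIMD elements per cycle on its input stream, and with the folding chosen above these widths differ at the first layer boundary.
%% Cell type:code id: tags:
``` python
# illustrative: compare the stream widths (in elements per cycle) at the first layer boundary
prod = getCustomOp(fc_layers[0])  # producer: output width governed by its PE
cons = getCustomOp(fc_layers[1])  # consumer: input width governed by its SIMD
print("layer 0 output elems/cycle (PE):  ", prod.get_nodeattr("PE"))
print("layer 1 input elems/cycle (SIMD): ", cons.get_nodeattr("SIMD"))
```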
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
model = model.transform(InsertDWC())
model = model.transform(InsertFIFO())
```
%% Cell type:markdown id: tags:
Finally, we will run the `InsertTLastMarker` transformation to get a `TLastMarker` node at the output of this graph, which is necessary to run the DMA engines correctly. Using netron we can then observe that the nodes now contain the chosen folding factors, that DWCs have been inserted where necessary, that FIFOs sit in between the nodes, and that the last node is the `TLastMarker` node inserted below.
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.insert_tlastmarker import InsertTLastMarker
model = model.transform(InsertTLastMarker())
model.save(build_dir+"/tfc_w1_a1_set_folding_factors.onnx")
showInNetron(build_dir+"/tfc_w1_a1_set_folding_factors.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1_a1_set_folding_factors.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe135b84780>
%% Cell type:markdown id: tags:
This completes the network preparation and the network can be passed on to the next block *Vivado HLS and IPI*, which is described below.
%% Cell type:markdown id: tags:
## 3. Vivado HLS and IPI <a id='vivado'></a>
* [Generating HLS Code](#hls_per_layer)
* [Synthesizing HLS to IP Blocks](#hls_synth)
* [IP Stitching](#ip_stitching)
As we will be dealing with FPGA synthesis tools in these tasks, we'll define two helper variables that describe the Xilinx FPGA part name and the PYNQ board name that we are targeting.
%% Cell type:code id: tags:
``` python
# print the names of the supported PYNQ boards
from finn.util.basic import pynq_part_map
print(pynq_part_map.keys())
```
%% Output
dict_keys(['Ultra96', 'Pynq-Z1', 'Pynq-Z2', 'ZCU104'])
%% Cell type:code id: tags:
``` python
# change this if you have a different PYNQ board, see list above
pynq_board = "Pynq-Z1"
fpga_part = pynq_part_map[pynq_board]
target_clk_ns = 10
```
%% Cell type:markdown id: tags:
### Generating HLS Code <a id='hls_per_layer'></a>
This section deals with the generation of an IP block for each of the layers. These can then be stitched together into a block design that corresponds to the complete model. Converting each layer into its own IP block gives good transparency: we can check the functionality of each IP block and compare it with the behaviour of the corresponding ONNX node.
%% Cell type:markdown id: tags:
Two transformations are required to generate HLS IP blocks for each layer:
* `PrepareIP` which generates the HLS C++ code for the node and a tcl-script which starts the HLS synthesis and exports the design as IP.
* `HLSSynthIP` which passes the tcl-script to Vivado HLS and thus performs the actual IP generation.
We start off by giving unique node names using the basic transformation `GiveUniqueNodeNames`, and then proceed with the HLS C++ code generation with `PrepareIP`.
%% Cell type:code id: tags:
``` python
model = ModelWrapper(build_dir+"/tfc_w1_a1_set_folding_factors.onnx")
model = model.transform(GiveUniqueNodeNames())
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
model = model.transform(PrepareIP(fpga_part, target_clk_ns))
```
%% Cell type:markdown id: tags:
### Synthesizing HLS to IP Blocks <a id='hls_synth'></a>
Now that we have generated the HLS code for each layer, we can call the `HLSSynthIP` transformation to convert the generated HLS into Vivado IP blocks. **As this involves calling HLS synthesis, this transformation will run for some time (several minutes).**
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
model = model.transform(HLSSynthIP())
model.save(build_dir+"/tfc_w1_a1_ipgen.onnx")
```
%% Cell type:markdown id: tags:
Each `StreamingFCLayer_Batch` node now has new attributes which can be examined more closely with netron.
%% Cell type:code id: tags:
``` python
showInNetron(build_dir+"/tfc_w1_a1_ipgen.onnx")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/workspace/finn/tfc_w1_a1_ipgen.onnx' at http://0.0.0.0:8081
<IPython.lib.display.IFrame at 0x7fe1346f7588>
%% Cell type:markdown id: tags:
There are two additional attributes:
* `code_gen_dir_ipgen` which contains the directory path where all the files generated by the ipgen transformations are stored
* `ipgen_path` which contains the path to the project directory in which the generated IP block is stored
We can further investigate which files are produced by taking a look into this directory, for example for the first StreamingFCLayer_Batch node.
%% Cell type:code id: tags:
``` python
fc0w = getCustomOp(model.graph.node[1])
code_gen_dir = fc0w.get_nodeattr("code_gen_dir_ipgen")
!ls {code_gen_dir}
```
%% Output
StreamingFCLayer_Batch_0_memstream.v thresh.h
hls_syn_StreamingFCLayer_Batch_0.tcl top_StreamingFCLayer_Batch_0.cpp
ipgen.sh vivado_hls.log
memblock_0.dat weights.npy
project_StreamingFCLayer_Batch_0
%% Cell type:markdown id: tags:
Directory *project_StreamingFCLayer_Batch_0* contains the project created by Vivado HLS into which the IP block is exported, along with other files generated by Vivado HLS. If we compare it to the above visualization of the network with netron, this is exactly the name of the folder stored in the node attribute `ipgen_path`. The .cpp code that is passed to Vivado HLS can be found in *top_StreamingFCLayer_Batch_0.cpp*; the file *thresh.h* belongs to it and contains the threshold values. The weights are stored as a .npy file and as a .dat file (*memblock_0.dat*). *vivado_hls.log* is the log file from Vivado HLS. Besides these files, the folder contains *ipgen.sh* and *hls_syn_StreamingFCLayer_Batch_0.tcl*. Because we use the StreamingFCLayer in "decoupled" mode, a verilog wrapper (*StreamingFCLayer_Batch_0_memstream.v*) is also produced; for more details on the "decoupled" and "const" modes please see the [FINN readthedocs website](https://finn.readthedocs.io/) under Internals.
In the following we take a closer look at the two generated scripts. We start with *ipgen.sh*.
%% Cell type:code id: tags:
``` python
shell_script = code_gen_dir + "/ipgen.sh"
!cat {shell_script}
```
%% Output
#!/bin/bash
cd /tmp/finn_dev_jakobap/code_gen_ipgen_StreamingFCLayer_Batch_0_edb__5oc
vivado_hls /tmp/finn_dev_jakobap/code_gen_ipgen_StreamingFCLayer_Batch_0_edb__5oc/hls_syn_StreamingFCLayer_Batch_0.tcl
cd /workspace/finn
%% Cell type:markdown id: tags:
The script consists of two framing `cd` commands and a command that passes the tcl script to *vivado_hls*. The working directory is changed so that the files are created in the correct folder, and is then changed back to the original directory.
Below is the tcl script which is passed to *vivado_hls*.
%% Cell type:code id: tags:
``` python
tcl_script = code_gen_dir + "/hls_syn_StreamingFCLayer_Batch_0.tcl"
!cat {tcl_script}
```
%% Output
set config_proj_name project_StreamingFCLayer_Batch_0
puts "HLS project: $config_proj_name"
set config_hwsrcdir "/tmp/finn_dev_jakobap/code_gen_ipgen_StreamingFCLayer_Batch_0_edb__5oc"
puts "HW source dir: $config_hwsrcdir"
set config_proj_part "xc7z020clg400-1"
set config_bnnlibdir "/workspace/finn-hlslib"
set config_toplevelfxn "StreamingFCLayer_Batch_0"
set config_clkperiod 10
open_project $config_proj_name
add_files $config_hwsrcdir/top_StreamingFCLayer_Batch_0.cpp -cflags "-std=c++0x -I$config_bnnlibdir"
set_top $config_toplevelfxn
open_solution sol1
set_part $config_proj_part
config_interface -m_axi_addr64
config_rtl -auto_prefix
create_clock -period $config_clkperiod -name default
csynth_design
export_design -format ip_catalog
exit 0
%% Cell type:markdown id: tags:
In the first part of the script the project is configured: for example, the FPGA part and the clock period are set. Then the project is opened, the source files are added and the top-level function is set. After creating a clock, the design is synthesized with `csynth_design` and exported as an IP block.
Now that all IP blocks are in place, they can be stitched together to create an IP design that matches the ONNX model. This is covered in the next section.
%% Cell type:markdown id: tags:
### IP Stitching <a id='ip_stitching'></a>
We now have IP blocks for each of our layers, and will stitch them together into a larger IP that implements the whole network using the `CreateStitchedIP` transformation. Bear in mind that this transformation can only be applied on a graph that only contains HLS nodes that already have been through the `HLSSynthIP` transformation, which is the last step we performed. Prior to calling IP stitching, we'll also use the `ReplaceVerilogRelPaths` transformation to convert any relative `$readmemh` paths in the generated IP blocks to absolute ones, which prevents errors later on. **This step invokes Vivado and may take a few minutes to run.**
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
from finn.transformation.fpgadataflow.replace_verilog_relpaths import ReplaceVerilogRelPaths
model = ModelWrapper(build_dir+"/tfc_w1_a1_ipgen.onnx")
model = model.transform(ReplaceVerilogRelPaths())
model = model.transform(CreateStitchedIP(fpga_part))
```
%% Cell type:markdown id: tags:
If you examine the nodes themselves on the transformed model you won't see a difference, because the IP stitching adds model-level metadata to the graph. This can be accessed using the `.model.metadata_props`, the `get_metadata_prop` function in `ModelWrapper`, or by clicking on the global input/output tensors in Netron.
%% Cell type:code id: tags:
``` python
model.model.metadata_props
```
%% Output
[key: "vivado_stitch_proj"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl"
, key: "vivado_stitch_vlnv"
value: "xilinx_finn:finn:finn_design:1.0"
, key: "wrapper_filename"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v"
]
%% Cell type:code id: tags:
``` python
model.get_metadata_prop("vivado_stitch_proj")
```
%% Output
'/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl'
%% Cell type:markdown id: tags:
If you navigate to the folder above (remember the /tmp/finn_xxx folder is mounted on the host as well as inside Docker) you can open the Vivado project (.xpr) file there using Vivado, and view the following stitched IP block design:
%% Cell type:markdown id: tags:
![](stitched_ip.png)
%% Cell type:code id: tags:
``` python
model.save(build_dir+"/tfc_w1_a1_ipstitch.onnx")
```
%% Cell type:markdown id: tags:
At this point, you could take the generated stitched IP and integrate it into your own project using Vivado IP Integrator if desired. Here, we will continue the tutorial by assuming that we want a stand-alone deployment of this accelerator on a PYNQ board.
%% Cell type:markdown id: tags:
## 4. PYNQ hardware generation and deployment <a id='hw_test'></a>
* [Inserting the IP into a PYNQ Overlay Shell](#pynq_shell)
* [Synthesis, Place and Route](#synth_pl_ro)
* [Driver Generation](#driver_gen)
* [Deployment and Remote Execution](#deploy)
* [Throughput Test on PYNQ Board](#throughput)
We are almost done preparing our hardware design. We'll now put it in a form suitable for use as a PYNQ overlay, synthesize and deploy it.
%% Cell type:markdown id: tags:
### Inserting the IP into a PYNQ Overlay Shell <a id='pynq_shell'></a>
To deploy our accelerator on a PYNQ platform, it needs to be put inside an appropriate *shell* that bridges it with the interfaces that the underlying system exposes. FINN makes it easy to create a PYNQ-compatible overlay by inserting the stitched IP into an appropriate PYNQ shell with the `MakePYNQProject` transformation; the created PYNQ shell project directory can then be inspected via the `metadata_props`. **This invokes Vivado and may take a few minutes to run.**
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.make_pynq_proj import MakePYNQProject
model = ModelWrapper(build_dir+"/tfc_w1_a1_ipstitch.onnx")
model = model.transform(MakePYNQProject(pynq_board))
model.model.metadata_props
```
%% Output
[key: "vivado_stitch_proj"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl"
, key: "vivado_stitch_vlnv"
value: "xilinx_finn:finn:finn_design:1.0"
, key: "wrapper_filename"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v"
, key: "vivado_pynq_proj"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs"
, key: "vivado_synth_rpt"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs/synth_report.xml"
]
%% Cell type:code id: tags:
``` python
! ls {model.get_metadata_prop("vivado_pynq_proj")}
```
%% Output
ip_config.tcl resizer.cache resizer.ip_user_files resizer.xpr
make_project.sh resizer.hw resizer.srcs synth_project.sh
%% Cell type:markdown id: tags:
If we open the created Vivado project (.xpr) under the `vivado_pynq_proj` directory above, we can see the system-level block design as below, with the FINN-generated part of the design highlighted. Various other components, such as the DMA engine and data width converters, have also been instantiated.
![](pynq_shell_project.png)
%% Cell type:code id: tags:
``` python
model.save(build_dir + "/tfc_w1_a1_pynq_project.onnx")
```
%% Cell type:markdown id: tags:
### Synthesis, Place and Route <a id='synth_pl_ro'></a>
%% Cell type:markdown id: tags:
We are now ready for the final hardware generation step, which is synthesis, place and route to generate an FPGA bitfile. This can be done by either running the `synth_project.sh` script in the generated Vivado PYNQ project directory inside Docker, or by executing the `SynthPYNQProject` transformation. **This step involves launching Vivado for synthesis and may take a few hours.**
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
model = ModelWrapper(build_dir + "/tfc_w1_a1_pynq_project.onnx")
model = model.transform(SynthPYNQProject())
model.model.metadata_props
```
%% Output
[key: "vivado_stitch_proj"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl"
, key: "vivado_stitch_vlnv"
value: "xilinx_finn:finn:finn_design:1.0"
, key: "wrapper_filename"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v"
, key: "vivado_pynq_proj"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs"
, key: "vivado_synth_rpt"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs/synth_report.xml"
, key: "vivado_pynq_bitfile"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs/resizer.bit"
]
%% Cell type:code id: tags:
``` python
model.save(build_dir + "/tfc_w1_a1_post_synthesis.onnx")
```
%% Cell type:markdown id: tags:
### Driver Generation <a id='driver_gen'></a>
Now that we have synthesized a bitfile for our network, we will generate some Python code for PYNQ that will act as the driver for this bitfile, package everything into a deployment folder and copy that to our PYNQ board.
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver
model = ModelWrapper(build_dir + "/tfc_w1_a1_post_synthesis.onnx")
model = model.transform(MakePYNQDriver())
```
%% Cell type:markdown id: tags:
The generated driver is placed in a folder that is indicated by the `pynq_driver_dir` top-level metadata. We can examine the generated PYNQ Python driver code as follows:
%% Cell type:code id: tags:
``` python
driver_dir = model.get_metadata_prop("pynq_driver_dir")
! cat {driver_dir}/driver.py
```
%% Output
import argparse
from pynq import Overlay
import numpy as np
from pynq import allocate
import time
from finn.util.data_packing import (
finnpy_to_packed_bytearray,
packed_bytearray_to_finnpy
)
from finn.core.datatype import DataType
class FINNAccelDriver():
def __init__(self, N, bitfile):
"""Instantiate the FINN accelerator driver.
Gets batchsize (N) as integer and path to bitfile as string."""
self.N = N
# input FINN DataType
self.idt = DataType.BINARY
# output FINN DataType
self.odt = DataType.UINT32
# input and output shapes
self.ishape_normal = (N, 784)
self.oshape_normal = (N, 10)
self.ishape_folded = (N, 16, 49)
self.oshape_folded = (N, 1, 10)
self.ishape_packed = (N, 16, 7) # datatype np.uint8
self.oshape_packed = (N, 1, 40) # datatype np.uint8
# load bitfile and set up accelerator
self.ol = Overlay(bitfile)
self.dma = self.ol.axi_dma_0
self.ctrl_regs = self.ol.resize_accel_0
# neuron folding factor of output = iterations per sample
self.itersPerSample = self.oshape_packed[-2]
# AXI lite register offset for number of iterations
# used by TLastMarker to signal end of transmission for AXI CDMA
self.REG_OFFSET_NUM_ITERS = 0x10
# set up TLastMarker with correct num. samples
self.ctrl_regs.write(self.REG_OFFSET_NUM_ITERS, self.N*self.itersPerSample)
# allocate a PYNQ buffer for the packed input and buffer
self.ibuf_packed_device = allocate(shape=self.ishape_packed, dtype=np.uint8)
self.obuf_packed_device = allocate(shape=self.oshape_packed, dtype=np.uint8)
def fold_input(self, ibuf_normal):
"""Reshapes input in desired shape.
Gets input data (ibuf_normal), checks if data is in expected normal shape.
Returns folded input."""
# ensure that shape is as expected
assert ibuf_normal.shape == self.ishape_normal
# convert to folded form
ibuf_folded = ibuf_normal.reshape(self.ishape_folded)
return ibuf_folded
def pack_input(self, ibuf_folded):
"""Packs folded input and reverses both SIMD dim and endianness.
Gets input data in folded shape and returns packed input data."""
ibuf_packed = finnpy_to_packed_bytearray(
ibuf_folded, self.idt, reverse_endian=True, reverse_inner=True
)
return ibuf_packed
def unpack_output(self, obuf_packed):
"""Unpacks the packed output buffer from accelerator.
Gets packed output and returns output data in folded shape."""
obuf_folded = packed_bytearray_to_finnpy(
obuf_packed, self.odt, self.oshape_folded, reverse_endian=True, reverse_inner=True
)
return obuf_folded
def unfold_output(self, obuf_folded):
"""Unfolds output data to normal shape.
Gets folded output data and returns output data in normal shape."""
obuf_normal = obuf_folded.reshape(self.oshape_normal)
return obuf_normal
def copy_input_data_to_device(self, data):
"""Copies given input data to PYNQ buffer."""
np.copyto(self.ibuf_packed_device, data)
def execute(self):
"""Executes accelerator by setting up the DMA and
waiting until all transfers complete. Uses only member variables and
returns nothing."""
dma = self.dma
dma.sendchannel.transfer(self.ibuf_packed_device)
dma.recvchannel.transfer(self.obuf_packed_device)
dma.sendchannel.wait()
dma.recvchannel.wait()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Set exec mode, batchsize N, bitfile name, inputfile name and outputfile name')
parser.add_argument('--exec_mode', help='Please select functional verification ("execute") or throughput test ("throughput_test")', default="execute")
parser.add_argument('--batchsize', help='number of samples for inference', type=int, default=1)
parser.add_argument('--bitfile', help='name of bitfile (i.e. "resizer.bit")', default="resizer.bit")
parser.add_argument('--inputfile', help='name of input npy file (i.e. "input.npy")', default="input.npy")
parser.add_argument('--outputfile', help='name of output npy file (i.e. "output.npy")', default="output.npy")
# parse arguments
args = parser.parse_args()
exec_mode = args.exec_mode
N = args.batchsize
bitfile = args.bitfile
inputfile = args.inputfile
outputfile = args.outputfile
# instantiate FINN accelerator driver and pass batchsize and bitfile
finnDriver = FINNAccelDriver(N, bitfile)
# for the remote execution the data from the input npy file has to be loaded,
# packed and copied to the PYNQ buffer
if exec_mode == "execute":
# load desired input .npy file
ibuf_normal = np.load(inputfile)
ibuf_folded = finnDriver.fold_input(ibuf_normal)
ibuf_packed = finnDriver.pack_input(ibuf_folded)
finnDriver.copy_input_data_to_device(ibuf_packed)
elif exec_mode != "throughput_test":
raise Exception("Exec mode has to be set to remote_pynq or throughput_test")
# for the throughput test the runtime of the network has to be measured
if exec_mode == "throughput_test":
# measure runtime of network
start = time.time()
# dictionary for results of throughput test
res={}
# execute accelerator
finnDriver.execute()
# measure run time and fill dictionary with results of the throughput test
if exec_mode == "throughput_test":
end = time.time()
runtime = end - start
res["runtime[ms]"] = runtime*1000
res["throughput[images/s]"] = N / runtime
res["DRAM_in_bandwidth[Mb/s]"] = np.prod(finnDriver.ishape_packed)*0.000001 / runtime
res["DRAM_out_bandwidth[Mb/s]"] = np.prod(finnDriver.oshape_packed)*0.000001 / runtime
file = open("nw_metrics.txt", "w")
file.write(str(res))
file.close()
# if execution is selected unpack, unfold and save output to output npy file
else:
obuf_folded = finnDriver.unpack_output(finnDriver.obuf_packed_device)
obuf_normal = finnDriver.unfold_output(obuf_folded)
np.save(outputfile, obuf_normal)
%% Cell type:markdown id: tags:
We can see that the generated driver implements a class for the FINN accelerator. The constructor gets the batchsize (N) as an integer and the path to the bitfile as a string. It also contains the expected input/output shapes, and takes care of instantiating the accelerator by loading the bitfile and setting up the DMA and buffers. Several member functions take care of data folding and packing. The function `copy_input_data_to_device` copies the input data into the PYNQ buffer, and `execute` sets up the DMA channels and waits until the transfer is completed. This class is used in the main section, which first parses the arguments passed to the script. The driver can be used in two modes: "execute" and "throughput_test". By default the arguments select "execute" mode with a batch size of 1, and the file arguments default to the names used by the FINN transformations.
The "execute" mode works as follows:
1. the data is loaded from the "inputfile"
2. the data is folded using `fold_input`
3. the data is packed using `pack_input`
4. the data is copied to the device using `copy_input_data_to_device`
5. FINNAccelDriver is executed using `execute`
6. the data is unpacked using `unpack_output`
7. the data is unfolded using `unfold_output`
8. the data is stored in the "outputfile"
If "throughput_test" is selected as `exec_mode`, no actual data needs to be loaded. The batchsize N should be set to a high value (i.e. 1000) and a time measurement is implemented in python. An empty dictionary (`res`) is created and after running the accelerator with the measured runtime it is filled with the metrics and saved in a .txt file.
You can build your own applications around the accelerator by modifying the driver, or use the remote execution capabilities that FINN provides just to check if it is working, which will be our next step.
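%% Cell type:markdown id: tags:
Before moving on, here is a usage sketch of the driver (assuming the default file names shown above). These commands are meant to be run on the PYNQ board itself, so we only list them here rather than executing them.
%% Cell type:code id: tags:
``` python
# illustrative only -- to be run on the PYNQ board, not from this notebook
# functional check on a single sample:
#   python3 driver.py --exec_mode=execute --batchsize=1 --bitfile=resizer.bit --inputfile=input.npy --outputfile=output.npy
# throughput measurement with a large batch (writes nw_metrics.txt):
#   python3 driver.py --exec_mode=throughput_test --batchsize=1000 --bitfile=resizer.bit
```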
%% Cell type:markdown id: tags:
### Deployment and Remote Execution <a id='deploy'></a>
We'll now use the `DeployToPYNQ` transformation to create a deployment folder with the bitfile and driver file(s), and copy that to the PYNQ board. You can change the default IP address, username, password and target folder for the PYNQ below.
%% Cell type:code id: tags:
``` python
from finn.transformation.fpgadataflow.make_deployment import DeployToPYNQ
ip = "192.168.3.1"
port = "22"
username = "xilinx"
password = "xilinx"
target_dir = "/home/xilinx/finn_tfc_end2end_example"
model = model.transform(DeployToPYNQ(ip, port, username, password, target_dir))
model.save(build_dir + "/tfc_w1_a1_pynq_deploy.onnx")
```
%% Cell type:markdown id: tags:
Let's verify that the remote access credentials are saved in the model metadata, and that the deployment folder has been successfully copied to the board:
%% Cell type:code id: tags:
``` python
model.model.metadata_props
```
%% Output
[key: "vivado_stitch_proj"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl"
, key: "vivado_stitch_vlnv"
value: "xilinx_finn:finn:finn_design:1.0"
, key: "wrapper_filename"
value: "/tmp/finn_dev_jakobap/vivado_stitch_proj_oa43bqzl/finn_vivado_stitch_proj.srcs/sources_1/bd/finn_design/hdl/finn_design_wrapper.v"
, key: "vivado_pynq_proj"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs"
, key: "vivado_synth_rpt"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs/synth_report.xml"
, key: "vivado_pynq_bitfile"
value: "/tmp/finn_dev_jakobap/vivado_pynq_proj_ljn53hfs/resizer.bit"
, key: "pynq_driver_dir"
value: "/tmp/finn_dev_jakobap/pynq_driver_j_9suyqm"
, key: "pynq_ip"
value: "51.37.47.42"
, key: "pynq_port"
value: "23"
, key: "pynq_username"
value: "xilinx"
, key: "pynq_password"
value: "x1l1nx_f!nn"
, key: "pynq_target_dir"
value: "/home/xilinx/finn_tfc_end2end_example"
, key: "pynq_deployment_dir"
value: "/tmp/finn_dev_jakobap/pynq_deployment_962qxwkv"
, key: "pynq_deploy_dir"
value: "/tmp/finn_dev_jakobap/pynq_deployment_962qxwkv"
, key: "exec_mode"
value: "remote_pynq"
]
%% Cell type:code id: tags:
``` python
! sshpass -p {password} ssh {username}@{ip} -p {port} 'ls -l {target_dir}/*'
```
%% Output
/home/xilinx/finn_tfc_end2end_example/pynq_deployment_26e8h5jo:
total 4276
-rw-r--r-- 1 xilinx xilinx 6363 May 7 10:35 driver.py
drwxr-xr-x 4 xilinx xilinx 4096 May 7 10:35 finn
-rw-r--r-- 1 xilinx xilinx 3264 May 7 10:55 input.npy
-rw-r--r-- 1 root root 172 May 7 10:37 nw_metrics.txt
-rw-r--r-- 1 root root 120 May 7 10:55 output.npy
-rw-r--r-- 1 xilinx xilinx 4045675 May 7 10:35 resizer.bit
-rw-r--r-- 1 xilinx xilinx 302015 May 7 10:35 resizer.hwh
-rw-r--r-- 1 root root 32 May 7 10:55 sds_trace_data.dat
/home/xilinx/finn_tfc_end2end_example/pynq_deployment_962qxwkv:
total 4260
-rw-r--r-- 1 xilinx xilinx 6363 May 7 17:44 driver.py
drwxr-xr-x 4 xilinx xilinx 4096 May 7 17:44 finn
-rw-r--r-- 1 xilinx xilinx 4045675 May 7 17:44 resizer.bit
-rw-r--r-- 1 xilinx xilinx 302015 May 7 17:44 resizer.hwh
/home/xilinx/finn_tfc_end2end_example/pynq_deployment_kvurnk0c:
total 4300
-rw-r--r-- 1 xilinx xilinx 3861 Apr 27 12:36 driver.py
drwxr-xr-x 4 xilinx xilinx 4096 Apr 27 12:37 finn
-rw-r--r-- 1 xilinx xilinx 3264 Apr 27 12:37 input.npy
-rw-r--r-- 1 root root 78 Apr 27 12:38 nw_metrics.txt
-rw-r--r-- 1 root root 120 Apr 27 12:37 output.npy
-rw-r--r-- 1 xilinx xilinx 4045675 Apr 27 12:36 resizer.bit
-rw-r--r-- 1 xilinx xilinx 329531 Apr 27 12:36 resizer.hwh
-rw-r--r-- 1 root root 32 Apr 27 12:38 sds_trace_data.dat
/home/xilinx/finn_tfc_end2end_example/pynq_deployment__tnbutz_:
total 4276
-rw-r--r-- 1 xilinx xilinx 6363 May 6 17:34 driver.py
drwxr-xr-x 4 xilinx xilinx 4096 May 6 17:34 finn
-rw-r--r-- 1 xilinx xilinx 3264 May 6 17:34 input.npy
-rw-r--r-- 1 root root 173 May 6 17:35 nw_metrics.txt
-rw-r--r-- 1 root root 120 May 6 17:34 output.npy
-rw-r--r-- 1 xilinx xilinx 4045675 May 6 17:34 resizer.bit
-rw-r--r-- 1 xilinx xilinx 302015 May 6 17:34 resizer.hwh
-rw-r--r-- 1 root root 32 May 6 17:35 sds_trace_data.dat
/home/xilinx/finn_tfc_end2end_example/pynq_deployment_w4aa1r9k:
total 4276
-rw-r--r-- 1 xilinx xilinx 6363 May 7 15:05 driver.py
drwxr-xr-x 4 xilinx xilinx 4096 May 7 15:05 finn
-rw-r--r-- 1 xilinx xilinx 3264 May 7 15:06 input.npy
-rw-r--r-- 1 root root 172 May 7 15:11 nw_metrics.txt
-rw-r--r-- 1 root root 120 May 7 15:06 output.npy
-rw-r--r-- 1 xilinx xilinx 4045675 May 7 15:05 resizer.bit
-rw-r--r-- 1 xilinx xilinx 302015 May 7 15:05 resizer.hwh
-rw-r--r-- 1 root root 32 May 7 15:11 sds_trace_data.dat
%% Cell type:markdown id: tags:
We only have two more steps to be able to remotely execute the deployed bitfile with some test data from the MNIST dataset. Let's load up some test data that comes bundled with FINN.
%% Cell type:code id: tags:
``` python
from pkgutil import get_data
import onnx.numpy_helper as nph
import matplotlib.pyplot as plt
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
x = nph.to_array(onnx.load_tensor_from_string(raw_i))
plt.imshow(x.reshape(28,28), cmap='gray')
```
%% Output
<matplotlib.image.AxesImage at 0x7fe11dda48d0>
%% Cell type:markdown id: tags:
Recall that we partitioned our original network into a parent graph that contained the non-synthesizable nodes and a child graph that contained the bulk of the network, which we turned into a bitfile. We'll now load up the parent graph and modify the `StreamingDataflowPartition` node so that it points to the deployed ONNX graph.
%% Cell type:code id: tags:
``` python
parent_model = ModelWrapper(build_dir+"/tfc_w1_a1_dataflow_parent.onnx")
sdp_node = parent_model.graph.node[2]
remote_exec_model = build_dir + "/tfc_w1_a1_pynq_deploy.onnx"
getCustomOp(sdp_node).set_nodeattr("model", remote_exec_model)
parent_model.save(build_dir+"/tfc_w1_a1_dataflow_parent_with_remote_bitfile_exec.onnx")
```
%% Cell type:markdown id: tags:
Finally, we can call `execute_onnx` on the parent graph, which will internally call remote execution with the bitfile once the `StreamingDataflowPartition` node is reached, grab the results, then continue executing the last portion of the network.
%% Cell type:code id: tags:
``` python
import numpy as np
from finn.core.onnx_exec import execute_onnx
iname = parent_model.graph.input[0].name
oname = parent_model.graph.output[0].name
ishape = parent_model.get_tensor_shape(iname)
input_dict = {iname: x.reshape(ishape)}
ret = execute_onnx(parent_model, input_dict, True)
```
%% Cell type:markdown id: tags:
We'll pass the output of the network through a softmax function to interpret it as probabilities, and plot the per-class probabilities as a bar chart.
%% Cell type:code id: tags:
``` python
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
logits = ret[oname].flatten()
prob = softmax(logits)
plt.bar(np.arange(10), prob)
```
%% Output
<BarContainer object of 10 artists>
%% Cell type:markdown id: tags:
We see that the network correctly predicts this as a digit 2 with high probability. This concludes our tutorial on how to take a simple fully-connected BNN all the way down to hardware with FINN, and execute it remotely on a PYNQ board.
%% Cell type:markdown id: tags:
### Throughput Test on PYNQ Board <a id='throughput'></a>
In addition to the functional verification, FINN also offers the possibility to measure the network performance directly on the PYNQ board. This can be done using the core function `throughput_test`. In the next section we import the function and execute it.
First we extract the `remote_exec_model` again and pass it to the function. The function returns the metrics of the network as a dictionary.
%% Cell type:code id: tags:
``` python
from finn.core.throughput_test import throughput_test
child_model = ModelWrapper(getCustomOp(sdp_node).get_nodeattr("model"))
res = throughput_test(child_model)
print("Network metrics:")
for key in res:
print(str(key) + ": " + str(res[key]))
```
%% Output
Network metrics:
runtime[ms]: 1.4772415161132812
throughput[images/s]: 676937.378954164
DRAM_in_bandwidth[Mb/s]: 75.81698644286635
DRAM_out_bandwidth[Mb/s]: 27.07749515816656
%% Cell type:markdown id: tags:
Together with the values for folding we can evaluate the performance of our accelerator. The slowest layers have a total folding factor of 64 and, because the network is fully pipelined, it follows that `II = 64`. II is the initiation interval, i.e. the number of cycles needed to process one input.
%% Cell type:code id: tags:
``` python
II = 64
# frequency in MHz
f_MHz = 100
# expected throughput in MFPS
expected_throughput = f_MHz / II
# measured throughput (FPS) from throughput test, converted to MFPS
measured_throughput = res["throughput[images/s]"] * 0.000001
# performance
print("We reach approximately " + str(round((measured_throughput / expected_throughput)*100)) + "% of the ideal performance.")
```
%% Output
We reach approximately 43% of the ideal performance.
%% Cell type:markdown id: tags:
The measured values were recorded with a batch size of 1000 and at a frequency of 100 MHz. We will be improving the efficiency of the generated accelerator examples in the coming FINN releases.