Skip to content
Snippets Groups Projects
Commit f92b6014 authored by auphelia's avatar auphelia
Browse files

[Notebooks] Update basics/1_brevitas_network_import notebook

parent 0876f7a3
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
# Importing Brevitas networks into FINN
In this notebook we'll go through an example of how to import a Brevitas-trained QNN into FINN. The steps will be as follows:
1. Load up the trained PyTorch model
2. Call Brevitas FINN-ONNX export and visualize with Netron
3. Import into FINN and call cleanup transformations
We'll use the following utility functions to print the source code for function calls (`showSrc()`) and to visualize a network using netron (`showInNetron()`) in the Jupyter notebook:
%% Cell type:code id: tags:
``` python
import onnx
from finn.util.visualization import showSrc, showInNetron
```
%% Cell type:markdown id: tags:
## 1. Load up the trained PyTorch model
The FINN Docker image comes with several [example Brevitas networks](https://github.com/Xilinx/brevitas/tree/master/src/brevitas_examples/bnn_pynq), and we'll use the LFC-w1a1 model as the example network here. This is a binarized fully connected network trained on the MNIST dataset. Let's start by looking at what the PyTorch network definition looks like:
%% Cell type:code id: tags:
``` python
from brevitas_examples import bnn_pynq
showSrc(bnn_pynq.models.FC)
```
%% Cell type:markdown id: tags:
We can see that the network topology is constructed using a few helper functions that generate the quantized linear layers and quantized activations. The bitwidth of the layers is actually parametrized in the constructor, so let's instantiate a 1-bit weights and activations version of this network. We also have pretrained weights for this network, which we will load into the model.
%% Cell type:code id: tags:
``` python
from finn.util.test import get_test_model
lfc = get_test_model(netname = "LFC", wbits = 1, abits = 1, pretrained = True)
lfc
```
%% Cell type:markdown id: tags:
We have now instantiated our trained PyTorch network. Let's try to run an example MNIST image through the network using PyTorch.
%% Cell type:code id: tags:
``` python
import torch
import matplotlib.pyplot as plt
from pkgutil import get_data
import onnx
import onnx.numpy_helper as nph
raw_i = get_data("finn.data", "onnx/mnist-conv/test_data_set_0/input_0.pb")
raw_i = get_data("qonnx.data", "onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
input_tensor_npy = nph.to_array(input_tensor)
input_tensor_pyt = torch.from_numpy(input_tensor_npy).float()
imgplot = plt.imshow(input_tensor_npy.reshape(28,28), cmap='gray')
```
%% Cell type:code id: tags:
``` python
from torch.nn.functional import softmax
# do forward pass in PyTorch/Brevitas
produced = lfc.forward(input_tensor_pyt).detach()
probabilities = softmax(produced, dim=-1).flatten()
probabilities
```
%% Cell type:code id: tags:
``` python
import numpy as np
objects = [str(x) for x in range(10)]
y_pos = np.arange(len(objects))
plt.bar(y_pos, probabilities, align='center', alpha=0.5)
plt.xticks(y_pos, objects)
plt.ylabel('Predicted Probability')
plt.title('LFC-w1a1 Predictions for Image')
plt.show()
```
%% Cell type:markdown id: tags:
## 2. Call Brevitas FINN-ONNX export and visualize with Netron
Brevitas comes with built-in FINN-ONNX export functionality. This is similar to the regular ONNX export capabilities of PyTorch, with a few differences:
1. The weight quantization logic is not exported as part of the graph; rather, the quantized weights themselves are exported.
2. Special quantization annotations are used to preserve the low-bit quantization information. ONNX (at the time of writing) supports 8-bit quantization as the minimum bitwidth, whereas FINN-ONNX quantization annotations can go down to binary/bipolar quantization.
3. Low-bit quantized activation functions are exported as MultiThreshold operators.
It's actually quite straightforward to export ONNX from our Brevitas model as follows:
%% Cell type:code id: tags:
``` python
import brevitas.onnx as bo
export_onnx_path = "/tmp/LFCW1A1.onnx"
input_shape = (1, 1, 28, 28)
bo.export_finn_onnx(lfc, input_shape, export_onnx_path)
```
%% Cell type:markdown id: tags:
Let's examine what the exported ONNX model looks like. For this, we will use the Netron visualizer:
%% Cell type:code id: tags:
``` python
showInNetron('/tmp/LFCW1A1.onnx')
```
%% Cell type:markdown id: tags:
When running this notebook in the FINN Docker container, you should be able to see an interactive visualization of the imported network above, and click on individual nodes to inspect their parameters. If you look at any of the MatMul nodes, you should be able to see that the weights are all {-1, +1} values, and the activations are Sign functions.
%% Cell type:markdown id: tags:
## 3. Import into FINN and call cleanup transformations
We will now import this ONNX model into FINN using the ModelWrapper, and examine some of the graph attributes from Python.
%% Cell type:code id: tags:
``` python
from qonnx.core.modelwrapper import ModelWrapper
model = ModelWrapper(export_onnx_path)
model.graph.node[8]
```
%% Cell type:markdown id: tags:
The ModelWrapper exposes a range of other useful functions as well. For instance, by convention the second input of the MatMul node will be a pre-initialized weight tensor, which we can view using the following:
%% Cell type:code id: tags:
``` python
model.get_initializer(model.graph.node[8].input[1])
```
%% Cell type:markdown id: tags:
We can also examine the quantization annotations and shapes of various tensors using the convenience functions provided by ModelWrapper.
%% Cell type:code id: tags:
``` python
model.get_tensor_datatype(model.graph.node[8].input[1]).name
```
%% Cell type:code id: tags:
``` python
model.get_tensor_shape(model.graph.node[8].input[1])
```
%% Cell type:markdown id: tags:
If we want to operate further on this model in FINN, it is a good idea to execute certain "cleanup" transformations on this graph. Here, we will run shape inference and constant folding on this graph, and visualize the resulting graph in Netron again.
%% Cell type:code id: tags:
``` python
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.infer_shapes import InferShapes
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
export_onnx_path_transformed = "/tmp/LFCW1A1-clean.onnx"
model.save(export_onnx_path_transformed)
```
%% Cell type:code id: tags:
``` python
showInNetron('/tmp/LFCW1A1-clean.onnx')
```
%% Cell type:markdown id: tags:
We can see that the resulting graph has become smaller and simpler. Specifically, the input reshaping is now a single Reshape node instead of the Shape -> Gather -> Unsqueeze -> Concat -> Reshape sequence. We can now use the internal ONNX execution capabilities of FINN to ensure that we still get the same output from this model as we did with PyTorch.
%% Cell type:code id: tags:
``` python
import finn.core.onnx_exec as oxe
input_dict = {"0": nph.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
produced_finn = output_dict[list(output_dict.keys())[0]]
produced_finn
```
%% Cell type:code id: tags:
``` python
np.isclose(produced, produced_finn).all()
```
%% Cell type:markdown id: tags:
We have succesfully verified that the transformed and cleaned-up FINN graph still produces the same output, and can now use this model for further processing in FINN.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment