Commit 70f2abed authored by Yaman Umuroglu

[Notebook] add qnt annot example to brevitas-network-import nb

parent 8cc51f8e
%% Cell type:markdown id: tags:
# Importing Brevitas networks into FINN
In this notebook we'll go through an example of how to import a Brevitas-trained QNN into FINN. The steps will be as follows:
1. Load up the trained PyTorch model
2. Call Brevitas FINN-ONNX export and visualize with Netron
3. Import into FINN and call cleanup transformations
We'll use the following showSrc function to print the source code for function calls in the Jupyter notebook:
%% Cell type:code id: tags:
``` python
import inspect

def showSrc(what):
    # getsourcelines returns (list_of_source_lines, starting_line_number)
    print("".join(inspect.getsourcelines(what)[0]))
```
%% Cell type:markdown id: tags:
## 1. Load up the trained PyTorch model
The FINN Docker image comes with several [example Brevitas networks](https://github.com/maltanar/brevitas_cnv_lfc), and we'll use the LFC-w1a1 model as the example network here. This is a binarized fully connected network trained on the MNIST dataset. Let's start by looking at what the PyTorch network definition looks like:
%% Cell type:code id: tags:
``` python
from models.LFC import LFC
showSrc(LFC)
```
%% Output
class LFC(Module):
    def __init__(self, num_classes=10, weight_bit_width=None, act_bit_width=None,
                 in_bit_width=None, in_ch=1, in_features=(28, 28)):
        super(LFC, self).__init__()
        weight_quant_type = get_quant_type(weight_bit_width)
        act_quant_type = get_quant_type(act_bit_width)
        in_quant_type = get_quant_type(in_bit_width)
        stats_op = get_stats_op(weight_quant_type)
        self.features = ModuleList()
        self.features.append(get_act_quant(in_bit_width, in_quant_type))
        self.features.append(Dropout(p=IN_DROPOUT))
        in_features = reduce(mul, in_features)
        for out_features in FC_OUT_FEATURES:
            self.features.append(get_quant_linear(in_features=in_features,
                                                  out_features=out_features,
                                                  per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
                                                  bit_width=weight_bit_width,
                                                  quant_type=weight_quant_type,
                                                  stats_op=stats_op))
            in_features = out_features
            self.features.append(BatchNorm1d(num_features=in_features))
            self.features.append(get_act_quant(act_bit_width, act_quant_type))
            self.features.append(Dropout(p=HIDDEN_DROPOUT))
        self.fc = get_quant_linear(in_features=in_features,
                                   out_features=num_classes,
                                   per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
                                   bit_width=weight_bit_width,
                                   quant_type=weight_quant_type,
                                   stats_op=stats_op)
    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = 2.0 * x - torch.tensor([1.0])
        for mod in self.features:
            x = mod(x)
        out = self.fc(x)
        return out
%% Cell type:markdown id: tags:
We can see that the network topology is constructed using a few helper functions that generate the quantized linear layers and quantized activations. The bit widths of the weights, activations and inputs are constructor parameters, so let's instantiate a version of this network with 1-bit weights and activations. We also have pretrained weights for this network, which we will load into the model.
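To give a rough feel for what a 1-bit (bipolar) configuration means numerically, here is a minimal pure-Python sketch. It mirrors the `2.0 * x - 1.0` input scaling from `LFC.forward` and then maps each value to {-1, +1}. The helper names `preprocess` and `bipolar_quant` are hypothetical, and mapping zero to +1 is an assumption of this sketch rather than something the notebook states.

```python
def preprocess(pixel):
    """Map a pixel intensity in [0, 1] to [-1, 1], as LFC.forward does."""
    return 2.0 * pixel - 1.0


def bipolar_quant(value):
    """Quantize a value to {-1.0, +1.0}. Assumption: zero maps to +1.0."""
    return 1.0 if value >= 0.0 else -1.0


# A few example pixel intensities, from dark (0.0) to bright (1.0)
pixels = [0.0, 0.2, 0.5, 0.9, 1.0]
binarized = [bipolar_quant(preprocess(p)) for p in pixels]
print(binarized)
```

Dark pixels end up as -1 and bright pixels as +1, which is the only information the first binarized layer gets to see.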
%% Cell type:code id: tags:
``` python
import torch
trained_lfc_w1a1_checkpoint = "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
lfc
```
%% Output
LFC(
(features): ModuleList(
(0): QuantHardTanh(
(act_quant_proxy): ActivationQuantProxy(
(fused_activation_quant_proxy): FusedActivationQuantProxy(
(activation_impl): Identity()
(tensor_quant): ClampedBinaryQuant(
(scaling_impl): StandaloneScaling(
(restrict_value): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
)
)
)
)
)
(1): Dropout(p=0.2)
(2): QuantLinear(
in_features=784, out_features=1024, bias=False
(weight_reg): WeightReg()
(weight_quant): WeightQuantProxy(
(tensor_quant): BinaryQuant(
(scaling_impl): ParameterStatsScaling(
(parameter_list_stats): ParameterListStats(
(first_tracked_param): _ViewParameterWrapper()
(stats): Stats(
(stats_impl): AbsAve()
)
)
(stats_scaling_impl): StatsScaling(
(affine_rescaling): Identity()
(restrict_scaling): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
(restrict_scaling_preprocess): LogTwo()
)
)
)
)
(bias_quant): BiasQuantProxy()
)
(3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): QuantHardTanh(
(act_quant_proxy): ActivationQuantProxy(
(fused_activation_quant_proxy): FusedActivationQuantProxy(
(activation_impl): Identity()
(tensor_quant): ClampedBinaryQuant(
(scaling_impl): StandaloneScaling(
(restrict_value): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
)
)
)
)
)
(5): Dropout(p=0.2)
(6): QuantLinear(
in_features=1024, out_features=1024, bias=False
(weight_reg): WeightReg()
(weight_quant): WeightQuantProxy(
(tensor_quant): BinaryQuant(
(scaling_impl): ParameterStatsScaling(
(parameter_list_stats): ParameterListStats(
(first_tracked_param): _ViewParameterWrapper()
(stats): Stats(
(stats_impl): AbsAve()
)
)
(stats_scaling_impl): StatsScaling(
(affine_rescaling): Identity()
(restrict_scaling): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
(restrict_scaling_preprocess): LogTwo()
)
)
)
)
(bias_quant): BiasQuantProxy()
)
(7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): QuantHardTanh(
(act_quant_proxy): ActivationQuantProxy(
(fused_activation_quant_proxy): FusedActivationQuantProxy(
(activation_impl): Identity()
(tensor_quant): ClampedBinaryQuant(
(scaling_impl): StandaloneScaling(
(restrict_value): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
)
)
)
)
)
(9): Dropout(p=0.2)
(10): QuantLinear(
in_features=1024, out_features=1024, bias=False
(weight_reg): WeightReg()
(weight_quant): WeightQuantProxy(
(tensor_quant): BinaryQuant(
(scaling_impl): ParameterStatsScaling(
(parameter_list_stats): ParameterListStats(
(first_tracked_param): _ViewParameterWrapper()
(stats): Stats(
(stats_impl): AbsAve()
)
)
(stats_scaling_impl): StatsScaling(
(affine_rescaling): Identity()
(restrict_scaling): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
(restrict_scaling_preprocess): LogTwo()
)
)
)
)
(bias_quant): BiasQuantProxy()
)
(11): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): QuantHardTanh(
(act_quant_proxy): ActivationQuantProxy(
(fused_activation_quant_proxy): FusedActivationQuantProxy(
(activation_impl): Identity()
(tensor_quant): ClampedBinaryQuant(
(scaling_impl): StandaloneScaling(
(restrict_value): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
)
)
)
)
)
(13): Dropout(p=0.2)
)
(fc): QuantLinear(
in_features=1024, out_features=10, bias=False
(weight_reg): WeightReg()
(weight_quant): WeightQuantProxy(
(tensor_quant): BinaryQuant(
(scaling_impl): ParameterStatsScaling(
(parameter_list_stats): ParameterListStats(
(first_tracked_param): _ViewParameterWrapper()
(stats): Stats(
(stats_impl): AbsAve()
)
)
(stats_scaling_impl): StatsScaling(
(affine_rescaling): Identity()
(restrict_scaling): RestrictValue(
(forward_impl): Sequential(
(0): PowerOfTwo()
(1): ClampMin()
)
)
(restrict_scaling_preprocess): LogTwo()
)
)
)
)
(bias_quant): BiasQuantProxy()
)
)
%% Cell type:markdown id: tags:
We have now instantiated our trained PyTorch network. Let's try to run an example MNIST image through the network using PyTorch.
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
from pkgutil import get_data
import onnx
import onnx.numpy_helper as nph
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
input_tensor_npy = nph.to_array(input_tensor)
input_tensor_pyt = torch.from_numpy(input_tensor_npy).float()
imgplot = plt.imshow(input_tensor_npy.reshape(28,28))
```
%% Output
%% Cell type:code id: tags:
``` python
from torch.nn.functional import softmax
# do forward pass in PyTorch/Brevitas
produced = lfc.forward(input_tensor_pyt).detach()
probabilities = softmax(produced, dim=-1).flatten()
probabilities
```
%% Output
tensor([2.4663e-03, 6.8211e-06, 8.9177e-01, 2.1330e-05, 3.6883e-04, 3.0418e-06,
1.1795e-04, 5.0158e-05, 1.0517e-01, 2.4597e-05])
%% Cell type:code id: tags:
``` python
import numpy as np
objects = [str(x) for x in range(10)]
y_pos = np.arange(len(objects))
plt.bar(y_pos, probabilities, align='center', alpha=0.5)
plt.xticks(y_pos, objects)
plt.ylabel('Predicted Probability')
plt.title('LFC-w1a1 Predictions for Image')
plt.show()
```
%% Output
%% Cell type:markdown id: tags:
## 2. Call Brevitas FINN-ONNX export and visualize with Netron
Brevitas comes with built-in FINN-ONNX export functionality. This is similar to the regular ONNX export capabilities of PyTorch, with a few differences:
1. The weight quantization logic is not exported as part of the graph; rather, the quantized weights themselves are exported.
2. Special quantization annotations are used to preserve the low-bit quantization information. ONNX (at the time of writing) supports 8-bit quantization as the minimum bitwidth, whereas FINN-ONNX quantization annotations can go down to binary/bipolar quantization.
3. Low-bit quantized activation functions are exported as MultiThreshold operators.
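As an illustration of the third point, a multi-threshold activation can be thought of as counting how many thresholds an input reaches, then scaling and shifting that count. The sketch below is pure Python and makes several assumptions that are not taken from this notebook: the `>=` comparison, the parameter names `out_scale` and `out_bias`, and the example threshold values.

```python
def multithreshold(x, thresholds, out_scale=1.0, out_bias=0.0):
    """Count the thresholds that x meets or exceeds, then scale and shift."""
    count = sum(1 for t in thresholds if x >= t)
    return out_scale * count + out_bias


def bipolar_act(x):
    """Bipolar activation: one threshold at 0, count {0, 1} mapped to {-1, +1}."""
    return multithreshold(x, [0.0], out_scale=2.0, out_bias=-1.0)


# A 2-bit unsigned activation needs three thresholds (hypothetical values)
two_bit = [multithreshold(x, [0.25, 0.5, 0.75]) for x in (0.1, 0.3, 0.6, 0.9)]

print([bipolar_act(x) for x in (-0.3, 0.0, 1.7)])
print(two_bit)
```

With a single threshold this reduces to a sign-like function, which is why a binarized network needs nothing more elaborate for its activations.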
It's actually quite straightforward to export ONNX from our Brevitas model as follows:
%% Cell type:code id: tags:
``` python
import brevitas.onnx as bo
export_onnx_path = "/tmp/LFCW1A1.onnx"
input_shape = (1, 1, 28, 28)
bo.export_finn_onnx(lfc, input_shape, export_onnx_path)
```
%% Output
/workspace/brevitas_cnv_lfc/training_scripts/models/LFC.py:73: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
x = 2.0 * x - torch.tensor([1.0])
%% Cell type:markdown id: tags:
Let's examine what the exported ONNX model looks like. For this, we will use the Netron visualizer:
%% Cell type:code id: tags:
``` python
import netron
netron.start(export_onnx_path, port=8081, host="0.0.0.0")
```
%% Output
Serving '/tmp/LFCW1A1.onnx' at http://0.0.0.0:8081
%% Cell type:code id: tags:
``` python
%%html
<iframe src="http://0.0.0.0:8081/" style="position: relative; width: 100%;" height="400"></iframe>
```
%% Output
%% Cell type:markdown id: tags:
When running this notebook in the FINN Docker container, you should be able to see an interactive visualization of the exported network above, and click on individual nodes to inspect their parameters. If you look at any of the MatMul nodes, you should be able to see that the weights are all {-1, +1} values, and the activations are Sign functions.
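To make the MatMul-plus-Sign structure concrete, here is a toy pure-Python sketch of one binarized layer with three inputs and two outputs. The weight matrix is laid out as `[in_features, out_features]`, and all values are made up for illustration. The batch normalization that sits between the linear layer and the activation in the real network is left out, and treating zero as positive in `sign` is an assumption.

```python
def matmul(vec, weights):
    """Multiply a row vector (length n) by an n x m matrix."""
    n_out = len(weights[0])
    return [sum(vec[i] * weights[i][j] for i in range(len(vec)))
            for j in range(n_out)]


def sign(value):
    """Bipolar sign. Assumption: zero is treated as positive."""
    return 1.0 if value >= 0.0 else -1.0


x = [1.0, -1.0, 1.0]          # bipolar input activations
W = [[ 1.0,  1.0],            # bipolar weights, shape [3, 2]
     [ 1.0, -1.0],
     [-1.0, -1.0]]

pre_activation = matmul(x, W)
activation = [sign(v) for v in pre_activation]
print(pre_activation, activation)
```

Because every operand is -1 or +1, each output is just a signed count of agreements between the input and a weight column.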
%% Cell type:markdown id: tags:
## 3. Import into FINN and call cleanup transformations
We will now import this ONNX model into FINN using the ModelWrapper, and examine some of the graph attributes from Python.
%% Cell type:code id: tags:
``` python
from finn.core.modelwrapper import ModelWrapper
model = ModelWrapper(export_onnx_path)
model.graph.node[9]
```
%% Output
input: "32"
input: "33"
output: "35"
op_type: "MatMul"
%% Cell type:markdown id: tags:
The ModelWrapper exposes a range of other useful functions as well. For instance, by convention the second input of the MatMul node will be a pre-initialized weight tensor, which we can view using the following:
%% Cell type:code id: tags:
``` python
model.get_initializer(model.graph.node[9].input[1])
```
%% Output
array([[ 1., 1., 1., ..., 1., 1., -1.],
[ 1., 1., -1., ..., 1., 1., -1.],
[-1., 1., -1., ..., -1., 1., -1.],
...,
[-1., 1., -1., ..., -1., -1., 1.],
[ 1., 1., -1., ..., 1., 1., -1.],
[-1., 1., 1., ..., -1., -1., 1.]], dtype=float32)
%% Cell type:markdown id: tags:
We can also examine the quantization annotations and shapes of various tensors using the convenience functions provided by ModelWrapper.
%% Cell type:code id: tags:
``` python
model.get_tensor_datatype(model.graph.node[9].input[1])
```
%% Output
<DataType.BIPOLAR: 8>
%% Cell type:code id: tags:
``` python
model.get_tensor_shape(model.graph.node[9].input[1])
```
%% Output
[784, 1024]
%% Cell type:markdown id: tags:
If we want to operate further on this model in FINN, it is a good idea to execute certain "cleanup" transformations on this graph. Here, we will run shape inference and constant folding on this graph, and visualize the resulting graph in Netron again.
%% Cell type:code id: tags:
``` python
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
export_onnx_path_transformed = "/tmp/LFCW1A1-clean.onnx"
model.save(export_onnx_path_transformed)
netron.start(export_onnx_path_transformed, port=8081, host="0.0.0.0")
```
%% Output
Stopping http://0.0.0.0:8081
Serving '/tmp/LFCW1A1-clean.onnx' at http://0.0.0.0:8081
%% Cell type:code id: tags:
``` python
%%html
<iframe src="http://0.0.0.0:8081/" style="position: relative; width: 100%;" height="400"></iframe>
```
%% Output
%% Cell type:markdown id: tags:
We can see that the resulting graph has become smaller and simpler. Specifically, the input reshaping is now a single Reshape node instead of the Shape -> Gather -> Unsqueeze -> Concat -> Reshape sequence. We can now use the internal ONNX execution capabilities of FINN to ensure that we still get the same output from this model as we did with PyTorch.
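To see why that chain can be folded, note that once the input shape is known, every intermediate value in it is a constant. The pure-Python sketch below walks through the chain for the `(1, 1, 28, 28)` input. The helper functions are hypothetical stand-ins for the ONNX operators, and the specific arguments (gathering index 0, concatenating with `[-1]`) are assumptions based on how `x.view(x.shape[0], -1)` is typically traced.

```python
def shape_op(dims):
    """Stand-in for ONNX Shape: return the tensor shape as a list."""
    return list(dims)


def gather(data, index):
    """Stand-in for ONNX Gather on a 1-D tensor."""
    return data[index]


def unsqueeze(scalar):
    """Stand-in for ONNX Unsqueeze: scalar -> 1-element list."""
    return [scalar]


def concat(a, b):
    """Stand-in for ONNX Concat on 1-D tensors."""
    return a + b


def resolve_reshape(input_shape, target):
    """Resolve a single -1 entry in a Reshape target shape."""
    total = 1
    for d in input_shape:
        total *= d
    known = 1
    for d in target:
        if d != -1:
            known *= d
    return [total // known if d == -1 else d for d in target]


input_shape = (1, 1, 28, 28)
# Shape -> Gather -> Unsqueeze -> Concat only depends on the static input shape
target = concat(unsqueeze(gather(shape_op(input_shape), 0)), [-1])
flat_shape = resolve_reshape(input_shape, target)
print(target, flat_shape)
```

Since `target` never depends on the input data, it can be precomputed once and stored as a constant, leaving only the final Reshape in the graph.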
%% Cell type:code id: tags:
``` python
import finn.core.onnx_exec as oxe
input_dict = {"0": nph.to_array(input_tensor)}
output_dict = oxe.execute_onnx(model, input_dict)
produced_finn = output_dict[list(output_dict.keys())[0]]
produced_finn
```
%% Output
array([[ 3.3252678 , -2.5652065 , 9.215742 , -1.4251148 , 1.4251148 ,
-3.3727715 , 0.28502294, -0.5700459 , 7.07807 , -1.2826033 ]],
dtype=float32)
%% Cell type:code id: tags:
``` python
np.isclose(produced, produced_finn).all()
```
%% Output
True
%% Cell type:markdown id: tags:
We have successfully verified that the transformed and cleaned-up FINN graph still produces the same output, and can now use this model for further processing in FINN.
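As an additional sanity check, applying softmax to the raw FINN outputs printed above should reproduce the class probabilities that PyTorch reported earlier (about 0.8918 for digit 2 and 0.1052 for digit 8). The following is a small pure-Python sketch. The `softmax` helper is written out by hand here and is not part of FINN.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    largest = max(logits)
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Raw outputs copied from the FINN execution result above
logits = [3.3252678, -2.5652065, 9.215742, -1.4251148, 1.4251148,
          -3.3727715, 0.28502294, -0.5700459, 7.07807, -1.2826033]

probs = softmax(logits)
predicted_digit = max(range(len(probs)), key=probs.__getitem__)
print(predicted_digit, round(probs[predicted_digit], 4))
```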