%% Cell type:markdown id: tags:
# FINN - CustomOps
-----------------------------------------------------------------
<font size="3">This notebook should give a more detailed insight into FINN custom operation nodes. </font>
%% Cell type:markdown id: tags:
<font size="3">Following showSrc function is used to print the source code of function calls in the Jupyter notebook: </font>
%% Cell type:code id: tags:
``` python
import inspect
def showSrc(what):
    print("".join(inspect.getsourcelines(what)[0]))
```
%% Cell type:markdown id: tags:
## FINN Custom Ops
---------------------------
<font size="3">FINN uses many custom operations (`op_type` in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with `domain="finn"` in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend.
A very abstract version of a custom op node representing a streaming fc layer is shown below. </font>
%% Cell type:markdown id: tags:
## Outline
---------------------------
* <font size="3">Basic FINN-ONNX node</font>
* <font size="3">CustomOp class</font>
* <font size="3">HLS FINN-ONNX node</font>
* <font size="3">HLSCustomOp class</font>
%% Cell type:markdown id: tags:
## Basic FINN-ONNX node
<font size="3">To create a FINN-ONNX node you can use the helper function of ONNX. Because it is an ONNX NodeProtobuf, but with several additional attributes. The procedure is shown with an example for a multithreshold node. </font>
`multithreshold_node = helper.make_node(
    "MultiThreshold",
    ["v", "thresholds"],
    ["out"],
    domain="finn",
    out_scale=2.0,
    out_bias=-1.0,
    out_dtype="",
)`
%% Cell type:markdown id: tags:
<font size="3">Unlike standard nodes, the custom op nodes has several additional attributes. The node is created using the helper function of ONNX. `"StreamingFCLayer_Batch"` describes the op_type, then the inputs and outputs are declared. Since this is a custom op node of FINN, the attribute `domain="finn"` must be set. The streaming fc layer is a custom op from the finn-hls library, this is set in the node using the `backend` attribute. To execute a custom op from the finn-hls library, the corresponding c++ code must be created and an executable must be produced. Where the generated code is stored is specified in the `code_gen_dir` attribute and `executable_path` specifies the path to the produced executable. In addition to the data types of the input and output tensors, the node also contains various other attributes resulting from the parameters of the corresponding finn-hls library function. This will not be discussed here.</font>
<font size="3">The `helper.make_node` function gets the op_type as first argument. In this case it is *MultiThreshold*. Then the inputs and outputs are passed. Beside the data input the multithreshold node has an additional input to pass the threshold values.
The next attribute (`domain`) is to specify that it is a FINN-ONNX node. It must be set to `"finn"`, so that the functions that work with FINN-ONNX nodes can directly recognize that it is a CustomOp. The attributes `out_scale` and `out_bias` are special multithreshold attributes to manipulate the output value. `out_dtype` contains the output data type.
**Note**: each FINN-ONNX node has its own special attributes, which must be set correctly to ensure proper processing.</font>
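%% Cell type:markdown id: tags:
<font size="3">Since the node above is only shown as a snippet, the following cell recreates it so it can be reused later in this notebook; the `domain` check at the end is a minimal sketch of how code working on a graph can recognize a FINN CustomOp:</font>
%% Cell type:code id: tags:
``` python
from onnx import helper

multithreshold_node = helper.make_node(
    "MultiThreshold",
    ["v", "thresholds"],
    ["out"],
    domain="finn",
    out_scale=2.0,
    out_bias=-1.0,
    out_dtype="",
)

# FINN custom ops carry domain="finn" in their NodeProto
print(multithreshold_node.domain == "finn")  # True
```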
%% Cell type:markdown id: tags:
<font size="3">Custom Ops are represented in FINN as ONNX nodes on the one hand and by a CustomOp class on the other hand. This allows easier access to the different attributes and introduces special custom op functions. See below for the standard CustomOp class.</font>
## CustomOp class
<font size="3">Custom Ops are represented in FINN as ONNX nodes on the one hand and by a CustomOp class on the other hand. This allows easier access to different attributes and introduces special custom op functions. See below for the standard CustomOp class.</font>
%% Cell type:code id: tags:
``` python
from finn.custom_op import CustomOp
showSrc(CustomOp)
```
%% Output
class CustomOp(ABC):
    def __init__(self, onnx_node):
        super().__init__()
        self.onnx_node = onnx_node

    def get_nodeattr(self, name):
        """Get a node attribute by name. Data is stored inside the ONNX node's
        AttributeProto container. Attribute must be part of get_nodeattr_types.
        Default value is returned if attribute is not set."""
        try:
            (dtype, req, def_val) = self.get_nodeattr_types()[name]
            attr = get_by_name(self.onnx_node.attribute, name)
            if attr is not None:
                # dtype indicates which ONNX Attribute member to use
                # (such as i, f, s...)
                ret = attr.__getattribute__(dtype)
                if dtype == "s":
                    # decode string attributes
                    ret = ret.decode("utf-8")
                return ret
            else:
                # not set, return default value
                return def_val
        except KeyError:
            raise AttributeError("Op has no such attribute: " + name)

    def set_nodeattr(self, name, value):
        """Set a node attribute by name. Data is stored inside the ONNX node's
        AttributeProto container. Attribute must be part of get_nodeattr_types."""
        try:
            (dtype, req, def_val) = self.get_nodeattr_types()[name]
            attr = get_by_name(self.onnx_node.attribute, name)
            if attr is not None:
                # dtype indicates which ONNX Attribute member to use
                # (such as i, f, s...)
                if dtype == "s":
                    # encode string attributes
                    value = value.encode("utf-8")
                attr.__setattr__(dtype, value)
            else:
                # not set, create and insert AttributeProto
                attr_proto = helper.make_attribute(name, value)
                self.onnx_node.attribute.append(attr_proto)
        except KeyError:
            raise AttributeError("Op has no such attribute: " + name)

    @abstractmethod
    def get_nodeattr_types(self):
        """Returns a dict of permitted attributes for node, where:
        returned_dict[attribute_name] = (dtype, require, default_value)
        - dtype indicates which member of the ONNX AttributeProto
          will be utilized
        - require indicates whether this attribute is required
        - default_val indicates the default value that will be used if the
          attribute is not set
        """
        pass

    @abstractmethod
    def make_shape_compatible_op(self):
        """Returns a standard ONNX op which is compatible with this CustomOp
        for performing shape inference."""
        pass

    @abstractmethod
    def infer_node_datatype(self, model):
        """Set the DataType annotations corresponding to the outputs of this
        node."""
        pass

    @abstractmethod
    def execute_node(self, context, graph):
        """Execute this CustomOp instance, given the execution context and
        ONNX graph."""
        pass
%% Cell type:markdown id: tags:
<font size="3">When instantiating the class, the ONNX node is passed to access all attributes of the node within the class. This is accompanied by the functions `get_nodeattr()`and `set_nodeattr()`, which each instance of this class has. Furthermore 4 abstract methods are implemented, which are described in more detail in the comments in the code. </font>
<font size="3">When instantiating the class, the ONNX node is passed to access all attributes of the node within the class. This is accompanied by the functions `get_nodeattr()`and `set_nodeattr()`, which each instance of this class has. Furthermore 4 abstract methods are implemented, which are described in more detail in the commands of the code and will be exemplarily explained for the multithreshold node in the following. </font>
%% Cell type:code id: tags:
``` python
from finn.custom_op.multithreshold import MultiThreshold
showSrc(MultiThreshold)
```
%% Output
class MultiThreshold(CustomOp):
    def get_nodeattr_types(self):
        return {
            "out_dtype": ("s", True, ""),
            "out_scale": ("f", False, 1.0),
            "out_bias": ("f", False, 0.0),
        }

    def make_shape_compatible_op(self):
        node = self.onnx_node
        return helper.make_node("Relu", [node.input[0]], [node.output[0]])

    def infer_node_datatype(self, model):
        node = self.onnx_node
        odt = self.get_nodeattr("out_dtype")
        model.set_tensor_datatype(node.output[0], DataType[odt])

    def execute_node(self, context, graph):
        node = self.onnx_node
        # save inputs
        v = context[node.input[0]]
        thresholds = context[node.input[1]]
        # retrieve attributes if output scaling is used
        out_scale = self.get_nodeattr("out_scale")
        out_bias = self.get_nodeattr("out_bias")
        # calculate output
        output = multithreshold(v, thresholds, out_scale, out_bias)
        # setting context according to output
        context[node.output[0]] = output
%% Cell type:markdown id: tags:
<font size="3"> `get_nodeattr_types`: returns a dict for the permitted attributes for node. It returns a triple with following values for each of the special multithreshold attributes. </font>
* <font size="3">`dtype`: indicates which member of the ONNX AttributeProto will be utilized </font>
* <font size="3">`require`: indicates whether this attribute is required </font>
* <font size="3">`default_value`: indicates the default value that will be used if the attribute is not set </font>
%% Cell type:markdown id: tags:
<font size="3">`make_shape_compatible_op`: To use the flow of FINN, the transformation pass [infer_shapes](https://github.com/Xilinx/finn/blob/dev/src/finn/transformation/infer_shapes.py) is applied to the graphs in various places. In order for this transformation to be applied to CustomOps, they must first be converted to standard ONNX nodes with the same shape behavior. This means, nodes where the relationship between input and output shape is the same.
This is done at this point. Since the output shape of a multithreshold node is the same as the input shape, it can be replaced by a `"Relu"` node from the standard node library of onnx.</font>
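%% Cell type:markdown id: tags:
<font size="3">As a sketch (the exact invocation depends on the FINN version; here it is assumed that the pass is available as an `InferShapes` transformation applied via `ModelWrapper.transform`, and that `model` is a ModelWrapper instance), shape inference would look like this:</font>
%% Cell type:code id: tags:
``` python
from finn.transformation.infer_shapes import InferShapes

# during this pass, CustomOps are temporarily replaced by their
# shape-compatible standard ONNX ops (e.g. MultiThreshold -> Relu)
model = model.transform(InferShapes())
```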
%% Cell type:markdown id: tags:
<font size="3">`infer_node_datatype`: sets the output tensor data type accordingly to the attribute `out_dtype` </font>
%% Cell type:markdown id: tags:
<font size="3">`execute_node`: This function allows the execution of the node, depending on the CustomOp a different functionality has to be implemented. In the case of the multithreshold node the input values and the thresholds are first extracted and after the attributes for the output scaling have been retrieved, the output is calculated with the help of a separate function. For more details regarding this function please take a look in the code [here](https://github.com/Xilinx/finn/blob/dev/src/finn/custom_op/multithreshold.py). </font>
%% Cell type:markdown id: tags:
<font size="3">If it is a node from the finn-hls library another class is used which is derived from the CustomOp class:</font>
<font size="3">FINN has a subset of CustomOps that correspond to the [finn-hls](https://finn-hlslib.readthedocs.io/en/latest/) library. In the next part of the Jupyter notebook these are described in more detail. </font>
%% Cell type:markdown id: tags:
## HLS FINN-ONNX node
<font size="3">The creation of an HLS FINN-ONNX node looks very similar to the creation of a basic FINN-ONNX node. But three new attributes are introduced that are necessary to enable the processing of HLS FINN-ONNX nodes in FINN.</font>
`FCLayer_node = helper.make_node(
    "StreamingFCLayer_Batch",
    node_inp_list,
    node_outp_list,
    domain="finn",
    backend="fpgadataflow",
    code_gen_dir="",
    executable_path="",
    resType="ap_resource_lut()",
    MW=mw,
    MH=mh,
    SIMD=simd,
    PE=pe,
    inputDataType=<FINN DataType>,
    weightDataType=<FINN DataType>,
    outputDataType=<FINN DataType>,
    ActVal=actval,
    binaryXnorMode=<0/1>,
    noActivation=<0/1>
)`
%% Cell type:markdown id: tags:
<font size="3">`"StreamingFCLayer_Batch"` describes the op_type, then the inputs and outputs are declared. This is still like building a default onnx node without additional attributes. But since this is a custom op node of FINN, the attribute `domain="finn"` must be set. The streaming fc layer is a custom op from the [finn-hls](https://finn-hlslib.readthedocs.io/en/latest/) library, this information is set in the node using the `backend` attribute. To execute a custom op from the [finn-hls](https://finn-hlslib.readthedocs.io/en/latest/) library, the corresponding c++ code must be created and an executable must be produced. Where the generated code is stored is specified in the `code_gen_dir` attribute and `executable_path` specifies the path to the produced executable. In addition to the data types of the input and output tensors, the node also contains various other attributes resulting from the parameters of the corresponding [finn-hls](https://finn-hlslib.readthedocs.io/en/latest/) library function. More detailed information can be found in the documentation of [finn-hlslib](https://finn-hlslib.readthedocs.io/en/latest/).</font>
%% Cell type:markdown id: tags:
## HLSCustomOp class
<font size="3">If it is a node from the [finn-hls](https://finn-hlslib.readthedocs.io/en/latest/) library another class is used which is derived from the CustomOp class:</font>
%% Cell type:code id: tags:
``` python
from finn.custom_op.fpgadataflow import HLSCustomOp
showSrc(HLSCustomOp)
```
%% Output
class HLSCustomOp(CustomOp):
    def __init__(self, onnx_node):
        super().__init__(onnx_node)
        # template for single node execution
        self.docompute_template = """
        #include "cnpy.h"
        #include "npy2apintstream.hpp"
        #include <vector>
        #include "bnn-library.h"
        // includes for network parameters
        $GLOBALS$
        // defines for network parameters
        $DEFINES$
        int main(){
        $STREAMDECLARATIONS$
        $READNPYDATA$
        $DOCOMPUTE$
        $DATAOUTSTREAM$
        $SAVEASCNPY$
        }
        """
        self.code_gen_dict = {}

    def get_nodeattr_types(self):
        return {"code_gen_dir": ("s", False, ""), "executable_path": ("s", False, "")}

    def code_generation(self, model):
        node = self.onnx_node
        self.generate_params(model)
        self.global_includes()
        self.defines()
        self.read_npy_data()
        self.strm_decl()
        self.docompute()
        self.dataoutstrm()
        self.save_as_npy()
        template = self.docompute_template
        for key in self.code_gen_dict:
            # transform list into long string separated by '\n'
            code_gen_line = "\n".join(self.code_gen_dict[key])
            template = template.replace(key, code_gen_line)
        code_gen_dir = self.get_nodeattr("code_gen_dir")
        f = open(os.path.join(code_gen_dir, "execute_{}.cpp".format(node.op_type)), "w")
        f.write(template)
        f.close()

    def compile_singlenode_code(self):
        code_gen_dir = self.get_nodeattr("code_gen_dir")
        builder = CppBuilder()
        builder.append_includes("-I/workspace/finn/src/finn/data/cpp")
        builder.append_includes("-I/workspace/cnpy/")
        builder.append_includes("-I/workspace/finn-hlslib")
        builder.append_includes("-I/workspace/vivado-hlslib")
        builder.append_includes("--std=c++11")
        builder.append_sources(code_gen_dir + "/*.cpp")
        builder.append_sources("/workspace/cnpy/cnpy.cpp")
        builder.append_includes("-lz")
        builder.set_executable_path(code_gen_dir + "/node_model")
        builder.build(code_gen_dir)
        self.set_nodeattr("executable_path", builder.executable_path)

    def dynamic_input_to_npy(self, context, count):
        node = self.onnx_node
        code_gen_dir = self.get_nodeattr("code_gen_dir")
        if code_gen_dir == "":
            raise Exception(
                """
Found no codegen dir for this node, did you run the codegen transformation?
                """
            )
        # create a npy file for each input of the node (in_ind is input index)
        # assuming dynamic inputs start from 0
        for in_ind in range(count):
            current_input_name = node.input[in_ind]
            np.save(
                os.path.join(code_gen_dir, "input_{}.npy".format(in_ind)),
                context[current_input_name],
            )

    def npy_to_dynamic_output(self, context):
        # TODO support multi-output nodes as needed
        node = self.onnx_node
        code_gen_dir = self.get_nodeattr("code_gen_dir")
        output = np.load("{}/output.npy".format(code_gen_dir))
        context[node.output[0]] = output

    def exec_precompiled_singlenode_model(self):
        # execute precompiled executable
        executable_path = self.get_nodeattr("executable_path")
        if executable_path == "":
            raise Exception(
                """
Found no executable for this node, did you run the codegen and
compilation transformations?
                """
            )
        process_execute = subprocess.Popen(executable_path, stdout=subprocess.PIPE)
        process_execute.communicate()

    def execute_node(self, context, graph):
        # save input(s)
        self.dynamic_input_to_npy(context, 1)
        # execute the precompiled model
        self.exec_precompiled_singlenode_model()
        # load output npy file
        self.npy_to_dynamic_output(context)

    def generate_params(self, model):
        pass

    @abstractmethod
    def global_includes(self):
        pass

    @abstractmethod
    def defines(self):
        pass

    @abstractmethod
    def read_npy_data(self):
        pass

    @abstractmethod
    def strm_decl(self):
        pass

    @abstractmethod
    def docompute(self):
        pass

    @abstractmethod
    def dataoutstrm(self):
        pass

    @abstractmethod
    def save_as_npy(self):
        pass
%% Cell type:markdown id: tags:
<font size="3">When creating an instance of this class, a template is introduced, which forms the layout for the c++ code to execute the node. It has some general constructs, like the inclusion of bnn-library.h, which contains the references to the finn-hls library, and of cnpy.h and npy2apintstream.hpp, which support the transfer of python numpy arrays in c++. The idea of this template is to replace the variables marked with `$ $` with c++ calls during code generation. Then the template can be written into a .cpp file and be compiled.
**`get_nodeattr_types()`**: each instance of the HLSCustomOp class must have the attributes `code_gen_dir` and `executable_path`, since to execute these nodes c++ code must be generated and correspondingly the executables.
</font>
%% Cell type:markdown id: tags:
<font size="3">**`code_generation(model)`**: all functions required for code generation are called and the `$ $` variables in the template are replaced accordingly and written into a .cpp file. Almost all of these subfunctions are implemented as abstract methods in the class, so they are completely customized for each custom op node. A special function is `generate_params()`. This is not implemented as an abstract method, but as a normal function, but contains by default only `pass`. This is because some custom op nodes do not have parameters that need to be generated and in this way the function is skipped. For example for a streaming fc layer node a parameter generation is necessary. How such a parameter generation can look like is described in more detail in the course of this notebook.
</font>
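%% Cell type:markdown id: tags:
<font size="3">The replacement mechanism itself is plain string substitution; the following standalone sketch (with made-up dictionary contents) mirrors the loop in `code_generation`:</font>
%% Cell type:code id: tags:
``` python
# each $KEY$ in the template is replaced by the joined list of
# code lines collected for that key in code_gen_dict
template = "int main(){\n$DOCOMPUTE$\n}"
code_gen_dict = {"$DOCOMPUTE$": ["// call HLS function", "// write results"]}

for key in code_gen_dict:
    code_gen_line = "\n".join(code_gen_dict[key])
    template = template.replace(key, code_gen_line)
print(template)
```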
%% Cell type:markdown id: tags:
<font size="3">**`compile_singlenode_code()`**: To compile the generated code, the compile command must be built. This is done in this function. It creates an instance of the `CppBuilder()` class and assembles the various components for the function. The `.build` function creates the executable and then sets the corresponding attribute. The class `CppBuilder` is a transformation and a more detailed description can be found in Jupyter notebook *FINN-CodeGenerationAndCompilation*.
<font size="3">**`compile_singlenode_code()`**: To compile the generated code, the compile command must be built. This is done in this function. It creates an instance of the `CppBuilder()` class and assembles the various components for the function. The `.build` function creates the executable and then sets the corresponding attribute. The class `CppBuilder` is a transformation and a more detailed description can be found in Jupyter notebook [FINN-CodeGenerationAndCompilation](FINN-CodeGenerationAndCompilation.ipynb).
</font>
%% Cell type:markdown id: tags:
<font size="3">**`dynamic_input_to_npy(context, count)`**:</font>
<font size="3">**`dynamic_input_to_npy(context, count)`**: creates a .npy file for all inputs of the node. These files will be stored in the directory specified by code_gen_dir. The argument `count` must be used to specify the number of inputs. `context` contains the values for the inputs.</font>
%% Cell type:markdown id: tags:
<font size="3">**`npy_to_dynamic_output(context)`**: reads the output values and sets `context` dictionary accordingly. When executing the c++ executable of the node, the output values are written to a .npy file. </font>
%% Cell type:markdown id: tags:
<font size="3">**`exec_precompiled_singlenode_model()`**: executes precompiled executable which is specified in `executable_path`</font>
%% Cell type:markdown id: tags:
<font size="3">**`execute_node(context,graph)`**: calls first `dynamic_input_to_npy()`, then executes the executable using `exec_precompiled_singlenode_model()` and at the end reads the output .npy file with `npy_to_dynamic_output`</font>
%% Cell type:markdown id: tags:
#### Generate Parameters
<font size="3">Parameters have to be generated for specific types of HLSCustomOps. For example if the node is a streaming fc layer, there are weights and activation values, which are written to separate .h files and added to the template using `#include`. For streaming fc layer the parameter generation looks like this:
</font>
%% Cell type:code id: tags:
``` python
from finn.custom_op.fpgadataflow.streamingfclayer_batch import StreamingFCLayer_Batch
showSrc(StreamingFCLayer_Batch.generate_params)
```
%% Output
    def generate_params(self, model):
        # weights
        weights = model.get_initializer(self.onnx_node.input[1])
        # convert weights into hlslib-compatible format
        weight_tensor = self.get_hls_compatible_weight_tensor(weights)
        export_wdt = self.get_weight_datatype()
        # we have converted bipolar weights to binary for export,
        # so use it as such for weight generation
        if self.get_weight_datatype() == DataType.BIPOLAR:
            export_wdt = DataType.BINARY
        weight_hls_code = numpy_to_hls_code(
            weight_tensor, export_wdt, "weights", True, True
        )
        # write weights into params.h
        code_gen_dir = self.get_nodeattr("code_gen_dir")
        f_weights = open("{}/params.h".format(code_gen_dir), "w")
        if export_wdt.bitwidth() != 1:
            f_weights.write(
                "static FixedPointWeights<{},{},{},{}> weights = ".format(
                    self.get_nodeattr("SIMD"),
                    export_wdt.get_hls_datatype_str(),
                    self.get_nodeattr("PE"),
                    self.calc_wmem(),
                )
            )
        else:
            f_weights.write(
                "static BinaryWeights<{},{},{}> weights = ".format(
                    self.get_nodeattr("SIMD"), self.get_nodeattr("PE"), self.calc_wmem()
                )
            )
        f_weights.write(weight_hls_code)
        f_weights.close()
        # thresholds
        if len(self.onnx_node.input) > 2:
            thresholds = model.get_initializer(self.onnx_node.input[2])
            if thresholds is not None:
                threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
                tdt = DataType.INT32
                # use UINT32 threshold export for bipolar times bipolar
                inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
                wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
                # reinterpret inp/wt as bipolar if binaryXnorMode is set
                inp_is_binary = self.get_input_datatype() == DataType.BINARY
                wt_is_binary = self.get_weight_datatype() == DataType.BINARY
                bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1
                inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode)
                wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode)
                if inp_is_bipolar and wt_is_bipolar:
                    tdt = DataType.UINT32
                thresholds_hls_code = numpy_to_hls_code(
                    threshold_tensor, tdt, "thresholds", False, True
                )
                # write thresholds into thresh.h
                code_gen_dir = self.get_nodeattr("code_gen_dir")
                f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
                tdt_hls = tdt.get_hls_datatype_str()
                # use binary to export bipolar activations
                export_odt = self.get_output_datatype()
                if self.get_output_datatype() == DataType.BIPOLAR:
                    export_odt = DataType.BINARY
                odt_hls = export_odt.get_hls_datatype_str()
                f_thresh.write(
                    "static ThresholdsActivation<{},{},{},{},{},{},{}> threshs \
                    = ".format(
                        self.calc_tmem(),
                        self.get_nodeattr("PE"),
                        threshold_tensor.shape[-1],
                        tdt_hls,
                        odt_hls,
                        self.get_nodeattr("ActVal"),
                        "std::less_equal<%s>" % tdt_hls,
                    )
                )
                f_thresh.write(thresholds_hls_code)
                f_thresh.close()
%% Cell type:markdown id: tags:
<font size="3">First, the values for the weights are extracted with `get_initializer()` using the ModelWrapper. At this point it is assumed that the second input of the streamingfclayer specifies the weights. After a few manipulations the weights are written in `params.h`. If there are threshold values, they will be prepared and written to `thresh.h`. </font>
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
# FINN - ModelWrapper
--------------------------------------
<font size="3"> This notebook is about the ModelWrapper class within FINN.
Following showSrc function is used to print the source code of function calls in the Jupyter notebook:</font>
%% Cell type:code id: tags:
``` python
import inspect
def showSrc(what):
    print("".join(inspect.getsourcelines(what)[0]))
```
%% Cell type:markdown id: tags:
## General Information
------------------------------
* <font size="3"> wrapper around ONNX ModelProto that exposes several utility
functions for graph manipulation and exploration </font>
* <font size="3"> ModelWrapper instance takes onnx model proto and `make_deepcopy` flag as input </font>
* <font size="3"> onnx model proto can either be a string with the path to a stored .onnx file on disk, or serialized bytes </font>
* <font size="3"> ModelWrapper instance takes ONNX ModelProto and `make_deepcopy` flag as input </font>
* <font size="3"> ONNX ModelProto can either be a string with the path to a stored .onnx file on disk, or serialized bytes </font>
* <font size="3"> `make_deepcopy` is by default False but can be set to True if a (deep) copy should be created </font>
%% Cell type:markdown id: tags:
### Create a ModelWrapper instance
<font size="3">Here we use a premade ONNX file on disk to load up the ModelWrapper, but this could have been produced from e.g. a trained Brevitas PyTorch model. See [this notebook](brevitas-network-import.ipynb) for more details.</font>
%% Cell type:code id: tags:
``` python
from finn.core.modelwrapper import ModelWrapper
model = ModelWrapper("LFCW1A1.onnx")
```
%% Cell type:markdown id: tags:
### Access the ONNX GraphProto through ModelWrapper
<font size="3">ModelWrapper is a thin wrapper around the ONNX protobuf, and it offers a range of helper functions as well as direct access to the underlying protobuf. The `.model` member gives access to the full ONNX ModelProto, whereas `.graph` gives access to the GraphProto, as follows:</font>
%% Cell type:code id: tags:
``` python
# access the ONNX ModelProto
modelproto = model.model
print("ModelProto IR version is %d" % modelproto.ir_version)
# the graph
graphproto = model.graph
print("GraphProto top-level outputs are %s" % str(graphproto.output))
# the node list
nodes = model.graph.node
print("There are %d nodes in the graph" % len(nodes))
print("The first node is \n%s" % str(nodes[0]))
```
%% Output
ModelProto IR version is 4
GraphProto top-level outputs are [name: "60"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 10
}
}
}
}
]
There are 29 nodes in the graph
The first node is
input: "0"
output: "21"
op_type: "Shape"
%% Cell type:markdown id: tags:
### Helper functions for tensors
<font size="3"> Every input and output of every node in the onnx model is represented as tensor with several properties (i.e. name, shape, data type). ModelWrapper provides some utility functions to work with the tensors. </font>
%% Cell type:markdown id: tags:
##### Get all tensor names
<font size="3">Produces a list of all tensor names (inputs, activations, weights, outputs...) in the graph.</font>
%% Cell type:code id: tags:
``` python
# get all tensor names
tensor_list = model.get_all_tensor_names()
print(tensor_list)
```
%% Output
['0', 'features.3.weight', 'features.3.bias', 'features.3.running_mean', 'features.3.running_var', 'features.7.weight', 'features.7.bias', 'features.7.running_mean', 'features.7.running_var', 'features.11.weight', 'features.11.bias', 'features.11.running_mean', 'features.11.running_var', '20', '23', '28', '30', '33', '34', '41', '42', '49', '50', '57', '58', '60']
%% Cell type:markdown id: tags:
##### Producer and consumer of a tensor
<font size="3">A tensor can have a producer node and/or a consumer node in the onnx model. ModelWrapper provides two helper functions to access these nodes, they are shown in the following.
It may be that a tensor does not have a producer or consumer node, for example if the tensor represents a constant that is already set. In that case `None` will be returned.</font>
%% Cell type:code id: tags:
``` python
# pick example tensors and find their producer and consumer (returns node)
tensor_name = tensor_list[25]
print("Producer node of tensor {}:".format(tensor_name))
print(model.find_producer(tensor_name))
tensor_name = tensor_list[0]
print("Consumer node of tensor {}:".format(tensor_name))
print(model.find_consumer(tensor_name))
print("Producer of tensor 0: %s" % str(model.find_producer("0")))
```
%% Output
Producer node of tensor 60:
input: "59"
input: "58"
output: "60"
op_type: "Mul"
Consumer node of tensor 0:
input: "0"
output: "21"
op_type: "Shape"
Producer of tensor 0: None
%% Cell type:markdown id: tags:
##### Tensor shape
<font size="3">Each tensor has a specific shape which can be accessed with the following ModelWrapper helper functions.</font>
%% Cell type:code id: tags:
``` python
# get tensor_shape
print("Shape of tensor 0 is %s" % str(model.get_tensor_shape("0")))
```
%% Output
Shape of tensor 0 is [1, 1, 28, 28]
%% Cell type:markdown id: tags:
<font size="3">It is also possible to set the tensor shape with a helper function. The syntax would be the following:
`onnx_model.set_tensor_shape(tensor_name, tensor_shape)`
Optionally, the dtype (container datatype) of the tensor can also be specified as third argument. By default it is set to TensorProto.FLOAT.
**Important:** dtype should not be confused with FINN data type, which specifies the quantization annotation. See the remarks about FINN-ONNX in [this notebook](finn-basics.ipynb). It is safest to use floating point tensors as the container data type for best compatibility inside FINN.</font>
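%% Cell type:markdown id: tags:
<font size="3">A short sketch of this helper, reusing tensor "0" from above (the shape is set to the value it already has, so the model is unchanged):</font>
%% Cell type:code id: tags:
``` python
from onnx import TensorProto

# set shape of tensor "0"; the third argument (container dtype) is optional
model.set_tensor_shape("0", [1, 1, 28, 28], TensorProto.FLOAT)
print(model.get_tensor_shape("0"))
```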
%% Cell type:markdown id: tags:
##### Tensor FINN DataType
%% Cell type:markdown id: tags:
<font size="3">FINN introduces its [own data types](https://github.com/Xilinx/finn/blob/dev/src/finn/core/datatype.py) because ONNX does not natively support precisions less than 8 bits. FINN is about quantized neural networks, so precision of i.e. 4 bits, 3 bits, 2 bits or 1 bit are of interest. To represent the data within FINN, float tensors are used with additional annotation to specify the quantized data type of a tensor. The following helper functions are about this quantization annotation.</font>
%% Cell type:code id: tags:
``` python
# get tensor data type (FINN data type)
print("The FINN DataType of tensor 0 is " + str(model.get_tensor_datatype("0")))
print("The FINN DataType of tensor 32 is " + str(model.get_tensor_datatype("32")))
```
%% Output
The FINN DataType of tensor 0 is DataType.FLOAT32
The FINN DataType of tensor 32 is DataType.BIPOLAR
%% Cell type:markdown id: tags:
<font size="3">In addition to the get_tensor_datatatype() function, the (FINN) datatype of a tensor can be set using the `set_tensor_datatype(tensor_name, datatype)` function.</font>
%% Cell type:markdown id: tags:
##### Tensor initializers
<font size="3">Some tensors have *initializers*, like tensors that represent constants or i.e. the trained weight values.
ModelWrapper contains two helper functions for this case, one to determine the current initializer and one to set the initializer of a tensor. If there is no initializer, `None` is returned.</font>
%% Cell type:code id: tags:
``` python
# get tensor initializer
print("Initializer for tensor 33:\n" + str(model.get_initializer("33")))
print("Initializer for tensor 0:\n" + str(model.get_initializer("0")))
```
%% Output
Initializer for tensor 33:
[[ 1. 1. 1. ... 1. 1. -1.]
[ 1. 1. -1. ... 1. 1. -1.]
[-1. 1. -1. ... -1. 1. -1.]
...
[-1. 1. -1. ... -1. -1. 1.]
[ 1. 1. -1. ... 1. 1. -1.]
[-1. 1. 1. ... -1. -1. 1.]]
Initializer for tensor 0:
None
%% Cell type:markdown id: tags:
<font size="3">Like for the other tensor helper functions there is a `set_initializer(tensor_name, tensor_value)` function.</font>
<font size="3">Like for the other tensor helper functions there is a corresponding set function (`set_initializer(tensor_name, tensor_value)`).</font>
%% Cell type:markdown id: tags:
### More helper functions
<font size="3">ModelWrapper contains more useful functions, if you are interested please have a look at the [Python code](https://github.com/Xilinx/finn/blob/dev/src/finn/core/modelwrapper.py) directly. Additionally, in the folder notebooks/ a Jupyter notebook about transformation passes and one about analysis passes can be found.</font>
<font size="3">ModelWrapper contains more useful functions, if you are interested please have a look at the [Python code](https://github.com/Xilinx/finn/blob/dev/src/finn/core/modelwrapper.py) directly. Additionally, in the folder notebooks/ a Jupyter notebook about transformation passes [FINN-HowToTransformationPass](FINN-HowToTransformationPass.ipynb) and one about analysis passes [FINN-HowToAnalysisPass](FINN-HowToAnalysisPass.ipynb) can be found.</font>
%% Cell type:code id: tags:
``` python
```