Skip to content
Snippets Groups Projects
Commit b464b112 authored by auphelia's avatar auphelia
Browse files

[notebook - trafo passes] Fixed grammar mistakes and included netron visualization

parent 6cc91587
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# FINN - Transformation passes # FINN - Transformation passes
-------------------------------------- --------------------------------------
<font size="3">In this notebook the idea behind transformation passes in FINN will be explained and with the help of an example the procedure of a transformation will be shown. <font size="3">In this notebook the idea behind transformation passes in FINN will be explained and with the help of an example the procedure of a transformation will be shown.
Following showSrc function is used to print the source code of function calls in the Jupyter notebook:</font> Following showSrc function is used to print the source code of function calls in the Jupyter notebook:</font>
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import inspect import inspect
def showSrc(what): def showSrc(what):
print("".join(inspect.getsourcelines(what)[0])) print("".join(inspect.getsourcelines(what)[0]))
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## General Information
-----------------------------
* <font size="3">changes (transforms) the given graph</font> * <font size="3">changes (transforms) the given graph</font>
* <font size="3">input: ModelWrapper</font> * <font size="3">input: ModelWrapper</font>
* <font size="3">returns the changed model (ModelWrapper) and flag `model_was_changed`</font> * <font size="3">returns the changed model (ModelWrapper) and flag `model_was_changed`</font>
%% Cell type:markdown id: tags:
## General Information
-----------------------------
<font size="3">Transformation passes have a base class and must inherit from that. Transformations are meant to be applied using .transform function from the ModelWrapper. This function makes a deep copy of the input model by default. The next cell shows .transform of ModelWrapper. </font> <font size="3">Transformation passes have a base class and must inherit from that. Transformations are meant to be applied using .transform function from the ModelWrapper. This function makes a deep copy of the input model by default. The next cell shows .transform of ModelWrapper. </font>
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### .transform() from ModelWrapper ### .transform() from ModelWrapper
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
showSrc(ModelWrapper.transform) showSrc(ModelWrapper.transform)
``` ```
%% Output %% Output
def transform(self, transformation, make_deepcopy=True): def transform(self, transformation, make_deepcopy=True):
"""Applies given Transformation repeatedly until no more changes can be made """Applies given Transformation repeatedly until no more changes can be made
and returns a transformed ModelWrapper instance. and returns a transformed ModelWrapper instance.
If make_deepcopy is specified, operates on a new (deep)copy of model. If make_deepcopy is specified, operates on a new (deep)copy of model.
""" """
transformed_model = self transformed_model = self
if make_deepcopy: if make_deepcopy:
transformed_model = copy.deepcopy(self) transformed_model = copy.deepcopy(self)
model_was_changed = True model_was_changed = True
while model_was_changed: while model_was_changed:
(transformed_model, model_was_changed) = transformation.apply( (transformed_model, model_was_changed) = transformation.apply(
transformed_model transformed_model
) )
return transformed_model return transformed_model
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
<font size="3">When the function is called, the model, the name of the transformation and, if required, the flag make_deepcopy are passed. It is also possible not to make a copy of the model. In this case `make_deepcopy` must be set to False. Then the branch `if make_deepcopy:` would not be taken and no copy of the model would be made. <font size="3">When the function is called, the model, the name of the transformation and, if required, the flag make_deepcopy are passed. It is also possible not to make a copy of the model. In this case `make_deepcopy` must be set to False. Then the branch `if make_deepcopy:` would not be taken and no copy of the model would be made.
The unchanged model is first passed to the variable `transformed_model` to pass this variable on to the transformation later. The unchanged model is first passed to the variable `transformed_model` to pass this variable on to the transformation later.
`model_was_changed` indicates whether the transformation needs to be applied more then once. Because it needs to be applied at least one time `model_was_changed` is first set to True and then depending on the return values of the transformation function the transformation can be applied more then once. `model_was_changed` indicates whether the transformation needs to be applied more then once. Because it needs to be applied at least one time `model_was_changed` is first set to True and then depending on the return values of the transformation function the transformation can be applied more then once.
**Important**: Due to the structure of this function, `model_was_changed` must be set to False at some point. Otherwise the loop is infinite. **Important**: Due to the structure of this function, `model_was_changed` must be set to False at some point. Otherwise the loop is infinite.
Each new transformation must correspond to the scheme of the base class and contain at least the function `apply(model)`, which returns the changed model and a bool value for `model_was_changed`. Each new transformation must correspond to the scheme of the base class and contain at least the function `apply(model)`, which returns the changed model and a bool value for `model_was_changed`.
</font> </font>
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### transformation base class ### transformation base class
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from finn.transformation import Transformation from finn.transformation import Transformation
showSrc(Transformation) showSrc(Transformation)
``` ```
%% Output %% Output
class Transformation(ABC): class Transformation(ABC):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@abstractmethod @abstractmethod
def apply(self, model): def apply(self, model):
pass pass
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
<font size="3"> Base class is abstract class (`import ABC`) with only one abstract method (`apply()`) which gets the model as input. To show what a transformation should look like, the following example is taken from FINN. </font> <font size="3"> Base class is abstract class (`import ABC`) with only one abstract method (`apply()`) which gets the model as input. To show what a transformation should look like, the following example is taken from FINN. </font>
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Example - ConvertSubToAdd ## Example - ConvertSubToAdd
----------------------------- -----------------------------
<font size="3">The transformation replaces all subtraction nodes in a model with addition nodes with appropriate sign. For that an onnx model is loaded which contains one subtraction node.</font> <font size="3">The transformation replaces all subtraction nodes in a model with addition nodes with appropriate sign. For that an onnx model is loaded which contains one subtraction node.
Netron is used to visualize the result before and after. </font>
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import onnx import onnx
onnx_model = onnx.load('LFCW1A1.onnx') onnx_model = onnx.load('LFCW1A1.onnx')
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
onnx_model = ModelWrapper(onnx_model) onnx_model = ModelWrapper(onnx_model)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import netron
netron.start('LFCW1A1.onnx', port=8081, host="0.0.0.0")
```
%% Output
Stopping http://0.0.0.0:8081
Serving 'LFCW1A1.onnx' at http://0.0.0.0:8081
%% Cell type:code id: tags:
``` python
%%html
<iframe src="http://0.0.0.0:8081/" style="position: relative; width: 100%;" height="400"></iframe>
```
%% Output
%% Cell type:code id: tags:
``` python
from finn.transformation import Transformation from finn.transformation import Transformation
class ConvertSubToAdd(Transformation): class ConvertSubToAdd(Transformation):
def apply(self, model): def apply(self, model):
graph = model.graph graph = model.graph
for n in graph.node: for n in graph.node:
if n.op_type == "Sub": if n.op_type == "Sub":
A = model.get_initializer(n.input[1]) A = model.get_initializer(n.input[1])
if A is not None: if A is not None:
n.op_type = "Add" n.op_type = "Add"
model.set_initializer(n.input[1], -A) model.set_initializer(n.input[1], -A)
return (model, False) return (model, False)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
<font size="3">First the transformation class must be imported. Then a class can be created for the new transformation, which is derived from the base class. In this case the transformation has only the `apply()` function. <font size="3">First the transformation class must be imported. Then a class can be created for the new transformation, which is derived from the base class. In this case the transformation has only the `apply()` function.
All nodes are checked by first extracting the graph from the model and then iterating over the node list. With the help of .op_type the operation type of the node can be checked, if the node is a subtraction node the condition `if n.op_type == "Sub"` is true. It may be that the subtraction input of the node has no value, this is checked with `model.get_initializer(n.input[1])`. All nodes are checked by first extracting the graph from the model and then iterating over the node list. With the help of .op_type the operation type of the node can be checked, if the node is a subtraction node the condition `if n.op_type == "Sub"` is true. It may be that the subtraction input of the node has no value, this is checked with `model.get_initializer(n.input[1])`.
**Important:** FINN always assumes a certain order of inputs, this is especially important if you want to create additional custom operation nodes. **Important:** FINN always assumes a certain order of inputs, this is especially important if you want to create additional custom operation nodes.
When the input is initialized, the operation type of the node is converted to `"Add"`, this can simply be done by using the equal sign. Then the sign of the initial value must be changed. For this the ModelWrapper function `.set_initializer` can be used. When the input is initialized, the operation type of the node is converted to `"Add"`, this can simply be done by using the equal sign. Then the sign of the initial value must be changed. For this the ModelWrapper function `.set_initializer` can be used.
At the end the changed model is returned and `model_was_changed` is set to False, because the transformation has to be executed only once.</font> At the end the changed model is returned and `model_was_changed` is set to False, because the transformation has to be executed only once.</font>
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
onnx_model_transformed = onnx_model.transform(ConvertSubToAdd()) onnx_model_transformed = onnx_model.transform(ConvertSubToAdd())
onnx_model_transformed.save('LFCW1A1_changed.onnx') onnx_model_transformed.save('LFCW1A1_changed.onnx')
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
netron.start('LFCW1A1_changed.onnx', port=8081, host="0.0.0.0")
```
%% Output
Stopping http://0.0.0.0:8081
Serving 'LFCW1A1_changed.onnx' at http://0.0.0.0:8081
%% Cell type:code id: tags:
``` python
%%html
<iframe src="http://0.0.0.0:8081/" style="position: relative; width: 100%;" height="400"></iframe>
```
%% Output
%% Cell type:code id: tags:
``` python
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment