Skip to content
Snippets Groups Projects
Commit 8ab4c41f authored by auphelia's avatar auphelia
Browse files

[Streamline] Add AbsorbTransposeIntoFlatten trafo

parent cf54f617
No related branches found
No related tags found
No related merge requests found
...@@ -28,8 +28,10 @@ ...@@ -28,8 +28,10 @@
import numpy as np import numpy as np
from onnx import helper as oh from onnx import helper as oh
import warnings
from finn.core.datatype import DataType from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.transformation import Transformation from finn.transformation import Transformation
from finn.util.basic import get_by_name from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp from finn.custom_op.registry import getCustomOp
...@@ -290,3 +292,65 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): ...@@ -290,3 +292,65 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
if graph_modified: if graph_modified:
model = model.transform(InferDataTypes()) model = model.transform(InferDataTypes())
return (model, graph_modified) return (model, graph_modified)
class AbsorbTransposeIntoFlatten(Transformation):
"""Absorb transpose node into succeeding flatten node, if H=W=1 and the first
dimension stays the same. Can also be applied if flatten is implemented implicitly
by a reshape node with shape [1, -1] and the first input dimension is 1"""
def apply(self, model):
graph = model.graph
graph_modified = False
node_ind = 0
for n in graph.node:
node_ind += 1
if (
n.op_type == "Reshape"
and (model.get_initializer(n.input[1]) == [1, -1]).all()
) or n.op_type == "Flatten":
prod = model.find_producer(n.input[0])
if (
prod is not None
and prod.op_type == "Transpose"
# we ensure that the first dimension is not changed from the
# transpose operation
and get_by_name(prod.attribute, "perm").ints[0] == 0
):
data_layout = model.get_tensor_layout(prod.input[0])
# check for the data layout to interpret input shape correctly
if data_layout is None:
warnings.warn(
"""Datalayout for input tensor of transpose node is not set.
To use AbsorbTransposeIntoFlatten transformation
please set tensor data layout."""
)
continue
elif data_layout == DataLayout.NCHW:
(b, c, h, w) = model.get_tensor_shape(prod.input[0])
# if h=w=1 the transposition can be absorbed, otherwise
# the absorption would lead to an error in the behavior
if h != 1 or w != 1:
continue
# the flatten node from onnx keeps by default the first
# dim and flattens the rest, that is why this transformation
# can only works with b != 1 if the model contains already a
# flatten node and not a reshape node with shape = [1, -1].
# If the first dim of the input tensor is not 1, flatten and
# reshape (with shape = [1, -1]) would lead to different results
if n.op_type == "Reshape" and b != 1:
continue
elif data_layout == DataLayout.NHWC:
(b, h, w, c) = model.get_tensor_shape(prod.input[0])
if h != 1 or w != 1:
continue
if n.op_type == "Reshape" and b != 1:
continue
# create single flatten node and remove obsolete nodes
node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
graph.node.remove(n)
graph.node.remove(prod)
graph.node.insert(node_ind, node)
graph_modified = True
model = model.transform(InferDataTypes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment