Commit 8ab4c41f authored by auphelia

[Streamline] Add AbsorbTransposeIntoFlatten trafo

parent cf54f617
@@ -28,8 +28,10 @@
import numpy as np
from onnx import helper as oh
import warnings
from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.transformation import Transformation
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
@@ -290,3 +292,65 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
        if graph_modified:
            model = model.transform(InferDataTypes())
        return (model, graph_modified)

class AbsorbTransposeIntoFlatten(Transformation):
    """Absorb a Transpose node into a succeeding Flatten node, if H=W=1 and the
    first dimension stays the same. Can also be applied if the flatten is
    implemented implicitly by a Reshape node with shape [1, -1] and the first
    input dimension is 1."""

    def apply(self, model):
        graph = model.graph
        graph_modified = False
        node_ind = 0
        for n in graph.node:
            node_ind += 1
            if (
                n.op_type == "Reshape"
                and (model.get_initializer(n.input[1]) == [1, -1]).all()
            ) or n.op_type == "Flatten":
                prod = model.find_producer(n.input[0])
                if (
                    prod is not None
                    and prod.op_type == "Transpose"
                    # ensure that the first dimension is not changed by the
                    # transpose operation
                    and get_by_name(prod.attribute, "perm").ints[0] == 0
                ):
                    data_layout = model.get_tensor_layout(prod.input[0])
                    # check the data layout to interpret the input shape correctly
                    if data_layout is None:
                        warnings.warn(
                            """Data layout for input tensor of Transpose node is not set.
                            To use the AbsorbTransposeIntoFlatten transformation,
                            please set the tensor data layout."""
                        )
                        continue
                    elif data_layout == DataLayout.NCHW:
                        (b, c, h, w) = model.get_tensor_shape(prod.input[0])
                        # if h=w=1 the transposition can be absorbed, otherwise
                        # the absorption would change the behavior
                        if h != 1 or w != 1:
                            continue
                        # the ONNX Flatten node keeps the first dim by default
                        # and flattens the rest, which is why this transformation
                        # can only work with b != 1 if the model already contains
                        # a Flatten node and not a Reshape node with shape = [1, -1].
                        # If the first dim of the input tensor is not 1, Flatten and
                        # Reshape (with shape = [1, -1]) would give different results
                        if n.op_type == "Reshape" and b != 1:
                            continue
                    elif data_layout == DataLayout.NHWC:
                        (b, h, w, c) = model.get_tensor_shape(prod.input[0])
                        if h != 1 or w != 1:
                            continue
                        if n.op_type == "Reshape" and b != 1:
                            continue
                    # create a single Flatten node and remove the obsolete nodes
                    node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
                    graph.node.remove(n)
                    graph.node.remove(prod)
                    graph.node.insert(node_ind, node)
                    graph_modified = True
        model = model.transform(InferDataTypes())
        return (model, graph_modified)
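
For intuition, a minimal numpy sketch (not part of the commit, values purely illustrative) of why the absorption is safe when H = W = 1: transposing an NCHW tensor to NHWC and then flattening to [1, -1] gives the same result as flattening the original tensor directly, so the Transpose node can be dropped.

    # illustrative check of the equivalence the transformation relies on
    import numpy as np

    x = np.arange(4).reshape(1, 4, 1, 1)  # NCHW tensor with b=1, c=4, h=w=1
    flatten_direct = x.reshape(1, -1)  # Flatten / Reshape([1, -1]) on the original input
    flatten_after_transpose = x.transpose(0, 2, 3, 1).reshape(1, -1)  # NCHW -> NHWC, then flatten
    assert (flatten_direct == flatten_after_transpose).all()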
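
A usage sketch under assumed names (the model file and tensor name below are hypothetical, not taken from the commit): the transformation only fires when the data layout of the Transpose input tensor is annotated, otherwise it warns and skips the node.

    # apply the new transformation to a FINN ModelWrapper
    import finn.core.data_layout as DataLayout
    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten

    model = ModelWrapper("model.onnx")  # hypothetical input model
    # annotate the layout of the Transpose input so the shape can be interpreted
    model.set_tensor_layout("transpose_input", DataLayout.NHWC)  # hypothetical tensor name
    model = model.transform(AbsorbTransposeIntoFlatten())
    model.save("model_streamlined.onnx")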