Skip to content
Snippets Groups Projects
Unverified Commit 31a3db47 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #169 from Xilinx/feature/absorb_transp_in_flatten

Feature/absorb transp in flatten
parents bb1db2e5 0f7e8ba1
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,7 @@ from onnx import helper as oh
import warnings
from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.transformation import Transformation
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
......@@ -357,7 +358,68 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
model = model.transform(InferDataTypes())
return (model, graph_modified)
class AbsorbTransposeIntoFlatten(Transformation):
"""Absorb transpose node into succeeding flatten node, if H=W=1 and the first
dimension stays the same. Can also be applied if flatten is implemented implicitly
by a reshape node with shape [1, -1] and the first input dimension is 1"""
def apply(self, model):
graph = model.graph
graph_modified = False
node_ind = 0
for n in graph.node:
node_ind += 1
if (
n.op_type == "Reshape"
and (model.get_initializer(n.input[1]) == [1, -1]).all()
) or n.op_type == "Flatten":
prod = model.find_producer(n.input[0])
if (
prod is not None
and prod.op_type == "Transpose"
# we ensure that the first dimension is not changed from the
# transpose operation
and get_by_name(prod.attribute, "perm").ints[0] == 0
):
data_layout = model.get_tensor_layout(prod.input[0])
# check for the data layout to interpret input shape correctly
if data_layout is None:
warnings.warn(
"""Data layout for input tensor of Transpose node is not set.
To use AbsorbTransposeIntoFlatten transformation
please set tensor data layout."""
)
continue
elif data_layout == DataLayout.NCHW:
(b, c, h, w) = model.get_tensor_shape(prod.input[0])
# if h=w=1 the transposition can be absorbed, otherwise
# the absorption would lead to an error in the behavior
if h != 1 or w != 1:
continue
# the flatten node from onnx keeps by default the first
# dim and flattens the rest, that is why this transformation
# can only work with b != 1 if the model contains already a
# flatten node and not a reshape node with shape = [1, -1].
# If the first dim of the input tensor is not 1, flatten and
# reshape (with shape = [1, -1]) would lead to different results
if n.op_type == "Reshape" and b != 1:
continue
elif data_layout == DataLayout.NHWC:
(b, h, w, c) = model.get_tensor_shape(prod.input[0])
if h != 1 or w != 1:
continue
if n.op_type == "Reshape" and b != 1:
continue
# create single flatten node and remove obsolete nodes
node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
graph.node.remove(n)
graph.node.remove(prod)
graph.node.insert(node_ind, node)
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
return (model, graph_modified)
class AbsorbScalarMulIntoTopK(Transformation):
"""Absorb a mul node into a suceeding topk node if the mul is scalar."""
......
import pytest
import numpy as np
from onnx import TensorProto, helper
from finn.core.modelwrapper import ModelWrapper
import finn.core.data_layout as DataLayout
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten
import finn.core.onnx_exec as oxe
# permutation of transpose node
@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
# reshape or flatten
@pytest.mark.parametrize("shape", [None, [1, -1], [-1, 1]])
# input shape
@pytest.mark.parametrize("ishape", [[1, 1, 1, 4], [2, 4, 1, 1], [1, 2, 2, 4]])
# datalayout
@pytest.mark.parametrize("data_layout", ["NCHW", "NHWC"])
def test_absorb_transp_into_flatten(perm, shape, ishape, data_layout):
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
dummy_in = np.random.uniform(low=0, high=1, size=tuple(ishape)).astype(np.float32)
if shape is None:
shape_node = helper.make_node("Flatten", ["transp_out"], ["outp"])
dummy_in = dummy_in.transpose(tuple(perm))
oshape = dummy_in.reshape(dummy_in.shape[0], -1).shape
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
shape0 = None
else:
shape0 = helper.make_tensor_value_info("shape0", TensorProto.FLOAT, shape)
shape_node = helper.make_node("Reshape", ["transp_out", "shape0"], ["outp"])
oshape = dummy_in.transpose(tuple(perm)).reshape(tuple(shape)).shape
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
graph = helper.make_graph(
nodes=[transp_node, shape_node],
name="absorb-transpose-graph",
inputs=[inp],
outputs=[outp],
)
model = helper.make_model(graph, producer_name="absorb_transpose_model")
model = ModelWrapper(model)
if shape is not None:
model.graph.value_info.append(shape0)
model.set_initializer("shape0", np.asarray(shape))
if data_layout == "NCHW":
model.set_tensor_layout("inp", DataLayout.NCHW)
else:
model.set_tensor_layout("inp", DataLayout.NHWC)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model.save("test.onnx")
model_transformed = model.transform(AbsorbTransposeIntoFlatten())
model_transformed.save("test2.onnx")
# verify transformation
inp_values = np.random.uniform(low=-1, high=1, size=tuple(ishape)).astype(
np.float32
)
idict = {model.graph.input[0].name: inp_values}
assert oxe.compare_execution(model, model_transformed, idict)
# only some of the parameter combinations lead to a graph that will be changed when
# AbsorbTransposeIntoFlatten is applied
if shape == [-1, 1]: # not a flatten operation, so the graph will not be changed
assert model.graph == model_transformed.graph
elif perm == [
3,
2,
0,
1,
]: # the first dimension is also part of the transpose operation
# so the graph will not be changed
assert model.graph == model_transformed.graph
# the following cases are the ones in which the model is transformed
# because we tested the parameters shape and perm befire we can only consider ishape
# and data_layout (the transformed model should only contain a "Flatten" node)
elif ishape == [1, 1, 1, 4] and data_layout == "NHWC":
assert model_transformed.graph.node[0].op_type == "Flatten"
elif ishape == [2, 4, 1, 1] and data_layout == "NCHW" and shape is None:
# If the first dimension of the input tensor is not 1, flatten and
# reshape (with shape = [1, -1]) would lead to different results
assert model_transformed.graph.node[0].op_type == "Flatten"
# all other cases lead to an unchanged model
else:
assert model.graph == model_transformed.graph
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment