diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index dc01eea411fc1f640e481c9be02a92acdd59533f..f089275c221f769daace3e9628a00bf87b4e5457 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -31,6 +31,7 @@
 from onnx import helper as oh
 import warnings
 from finn.core.datatype import DataType
+import finn.core.data_layout as DataLayout
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
@@ -357,7 +358,71 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         model = model.transform(InferDataTypes())
         return (model, graph_modified)
 
+
+class AbsorbTransposeIntoFlatten(Transformation):
+    """Absorb a Transpose node into a succeeding Flatten node, if H=W=1 and the
+    first dimension stays the same. Can also be applied if flatten is implemented
+    implicitly by a Reshape node with shape [1, -1] and the first input dim is 1."""
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "Reshape"
+                and model.get_initializer(n.input[1]) is not None
+                and (model.get_initializer(n.input[1]) == [1, -1]).all()
+            ) or n.op_type == "Flatten":
+                prod = model.find_producer(n.input[0])
+                if (
+                    prod is not None
+                    and prod.op_type == "Transpose"
+                    # ensure that the first dimension is not changed by the
+                    # transpose operation
+                    and get_by_name(prod.attribute, "perm").ints[0] == 0
+                ):
+                    data_layout = model.get_tensor_layout(prod.input[0])
+                    # check the data layout to interpret the input shape correctly
+                    if data_layout is None:
+                        warnings.warn(
+                            """Data layout for input tensor of Transpose node is not set.
+                            To use the AbsorbTransposeIntoFlatten transformation
+                            please set the tensor data layout."""
+                        )
+                        continue
+                    elif data_layout == DataLayout.NCHW:
+                        (b, c, h, w) = model.get_tensor_shape(prod.input[0])
+                        # if h=w=1 the transposition can be absorbed, otherwise
+                        # the absorption would change the model's behavior
+                        if h != 1 or w != 1:
+                            continue
+                        # the ONNX Flatten node keeps the first dim by default and
+                        # flattens the rest, which is why this transformation can
+                        # only handle b != 1 if the model already contains a Flatten
+                        # node rather than a Reshape node with shape = [1, -1].
+                        # If the first dim of the input tensor is not 1, Flatten and
+                        # Reshape (with shape = [1, -1]) produce different results
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    elif data_layout == DataLayout.NHWC:
+                        (b, h, w, c) = model.get_tensor_shape(prod.input[0])
+                        if h != 1 or w != 1:
+                            continue
+                        if n.op_type == "Reshape" and b != 1:
+                            continue
+                    # create a single Flatten node and remove the obsolete nodes
+                    node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
+                    graph.node.remove(n)
+                    graph.node.remove(prod)
+                    graph.node.insert(node_ind, node)
+                    graph_modified = True
+        if graph_modified:
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class AbsorbScalarMulIntoTopK(Transformation):
     """Absorb a mul node into a suceeding topk node if the mul is scalar."""
 
diff --git a/tests/transformation/test_absorb_transp_into_flatten.py b/tests/transformation/test_absorb_transp_into_flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbfa15277717c554da01e38608601997407803b2
--- /dev/null
+++ b/tests/transformation/test_absorb_transp_into_flatten.py
@@ -0,0 +1,93 @@
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.data_layout as DataLayout
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
+from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten
+import finn.core.onnx_exec as oxe
+
+
+# permutation of transpose node
+@pytest.mark.parametrize("perm", [[0, 2, 3, 1], [0, 1, 3, 2], [3, 2, 0, 1]])
+# reshape or flatten
+@pytest.mark.parametrize("shape", [None, [1, -1], [-1, 1]])
+# input shape
+@pytest.mark.parametrize("ishape", [[1, 1, 1, 4], [2, 4, 1, 1], [1, 2, 2, 4]])
+# data layout
+@pytest.mark.parametrize("data_layout", ["NCHW", "NHWC"])
+def test_absorb_transp_into_flatten(perm, shape, ishape, data_layout):
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
+    transp_node = helper.make_node("Transpose", ["inp"], ["transp_out"], perm=perm)
+    dummy_in = np.random.uniform(low=0, high=1, size=tuple(ishape)).astype(np.float32)
+    if shape is None:
+        shape_node = helper.make_node("Flatten", ["transp_out"], ["outp"])
+        dummy_in = dummy_in.transpose(tuple(perm))
+        oshape = dummy_in.reshape(dummy_in.shape[0], -1).shape
+        outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
+        shape0 = None
+    else:
+        # the shape input of a Reshape node must be a 1-D int64 tensor
+        shape0 = helper.make_tensor_value_info(
+            "shape0", TensorProto.INT64, [len(shape)]
+        )
+        shape_node = helper.make_node("Reshape", ["transp_out", "shape0"], ["outp"])
+        oshape = dummy_in.transpose(tuple(perm)).reshape(tuple(shape)).shape
+        outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
+
+    graph = helper.make_graph(
+        nodes=[transp_node, shape_node],
+        name="absorb-transpose-graph",
+        inputs=[inp],
+        outputs=[outp],
+    )
+
+    model = helper.make_model(graph, producer_name="absorb_transpose_model")
+    model = ModelWrapper(model)
+    if shape is not None:
+        model.graph.value_info.append(shape0)
+        model.set_initializer("shape0", np.asarray(shape, dtype=np.int64))
+    if data_layout == "NCHW":
+        model.set_tensor_layout("inp", DataLayout.NCHW)
+    else:
+        model.set_tensor_layout("inp", DataLayout.NHWC)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+    model = model.transform(InferDataLayouts())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    model_transformed = model.transform(AbsorbTransposeIntoFlatten())
+
+    # verify that the transformation does not change the execution results
+    inp_values = np.random.uniform(low=-1, high=1, size=tuple(ishape)).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_values}
+    assert oxe.compare_execution(model, model_transformed, idict)
+
+    # only some of the parameter combinations lead to a graph that is changed
+    # when AbsorbTransposeIntoFlatten is applied
+    if shape == [-1, 1]:
+        # not a flatten operation, so the graph is not changed
+        assert model.graph == model_transformed.graph
+    elif perm == [3, 2, 0, 1]:
+        # the first dimension is also part of the transpose operation,
+        # so the graph is not changed
+        assert model.graph == model_transformed.graph
+    # the following cases are the ones in which the model is transformed; since
+    # shape and perm are already handled above, only ishape and data_layout need
+    # to be considered (the transformed model should contain only a Flatten node)
+    elif ishape == [1, 1, 1, 4] and data_layout == "NHWC":
+        assert model_transformed.graph.node[0].op_type == "Flatten"
+    elif ishape == [2, 4, 1, 1] and data_layout == "NCHW" and shape is None:
+        # if the first dimension of the input tensor is not 1, Flatten and
+        # Reshape (with shape = [1, -1]) would lead to different results
+        assert model_transformed.graph.node[0].op_type == "Flatten"
+    # all other cases lead to an unchanged model
+    else:
+        assert model.graph == model_transformed.graph
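
Reviewer note (not part of the patch): a quick NumPy sanity check of the condition the transformation relies on. When H = W = 1, transposing NCHW to NHWC before flattening leaves the element order unchanged, so the Transpose can be absorbed; with a non-trivial spatial extent the flattened orders differ:

```python
import numpy as np

# H = W = 1: flattening directly and flattening after an NCHW -> NHWC
# transpose give identical results, so the Transpose node is redundant
x = np.arange(8, dtype=np.float32).reshape(1, 8, 1, 1)  # NCHW, H = W = 1
direct = x.reshape(x.shape[0], -1)
via_transpose = x.transpose(0, 2, 3, 1).reshape(x.shape[0], -1)
assert (direct == via_transpose).all()

# H = W = 2: the element orders differ, so absorbing the Transpose
# here would change the model's behavior
y = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)  # NCHW, H = W = 2
assert not (y.reshape(1, -1) == y.transpose(0, 2, 3, 1).reshape(1, -1)).all()
```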
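And a minimal usage sketch, assuming a hypothetical "model.onnx" that contains a Transpose followed by a Flatten (or a Reshape to [1, -1]); data layouts must be annotated first, since the transformation skips Transpose nodes whose input layout is unset:

```python
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.streamline.absorb import AbsorbTransposeIntoFlatten

# "model.onnx" is a placeholder for a FINN-compatible ONNX model
model = ModelWrapper("model.onnx")
# annotate tensor data layouts so the absorption check can interpret shapes
model = model.transform(InferDataLayouts())
model = model.transform(AbsorbTransposeIntoFlatten())
```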