diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 2b03532ce3ba7d5159e5ae57e61c2af9c8c37fce..b47f269dd6f2671c3d98c9316954483c0e72f14f 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -502,6 +502,73 @@ class MoveLinearPastEltwiseAdd(Transformation):
         return (model, graph_modified)
 
 
+class MoveScalarLinearPastInvariants(Transformation):
+    """Move scalar linear operations (mul, add) past functions which are invariant
+    to them. Specifically, matches and transforms the following patterns:
+    f(x*C) -> f(x) * C
+    f(x+C) -> f(x) + C
+    where x is a dynamic input and C is a constant scalar tensor.
+    Known f which obey this property are: Reshape, Flatten, Transpose,
+    GlobalAveragePool
+    """
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        nodes = [n for n in graph.node]
+        for n in nodes:
+            node_ind += 1
+            if (
+                n.op_type == "GlobalAveragePool"
+                or n.op_type == "Reshape"
+                or n.op_type == "Transpose"
+                or n.op_type == "Flatten"
+            ):
+                in0 = n.input[0]
+                if in0 is None:
+                    continue
+                # find and check producer on our input
+                prod0 = model.find_producer(in0)
+                if prod0 is None:
+                    continue
+
+                if prod0.op_type == "Mul" or prod0.op_type == "Add":
+                    # check if second input of producer is an initializer
+                    init0 = model.get_initializer(prod0.input[1])
+                    # if the initializer is None, skip
+                    if init0 is None:
+                        continue
+                    # if initializer is not scalar, skip
+                    if np.prod(init0.shape) != 1:
+                        continue
+                    # move prod0 from the input to the output side of n
+                    old_prod0_in = prod0.input[0]
+                    old_prod0_out = prod0.output[0]
+                    scalar_op_odt = model.get_tensor_datatype(old_prod0_out)
+                    old_n_out = n.output[0]
+                    in_shape = model.get_tensor_shape(n.input[0])
+                    out_shape = model.get_tensor_shape(n.output[0])
+                    n.input[0] = old_prod0_in
+                    n.output[0] = old_prod0_out
+                    prod0.input[0] = old_prod0_out
+                    prod0.output[0] = old_n_out
+                    model.set_tensor_shape(n.input[0], in_shape)
+                    model.set_tensor_shape(n.output[0], out_shape)
+                    model.set_tensor_shape(prod0.output[0], out_shape)
+                    model.set_tensor_datatype(prod0.output[0], scalar_op_odt)
+                    model.set_tensor_datatype(n.output[0], DataType.FLOAT32)
+                    graph.node.remove(prod0)
+                    graph.node.insert(node_ind - 1, prod0)
+                    graph_modified = True
+                else:
+                    continue
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class MakeMaxPoolNHWC(Transformation):
     """Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC)."""
 
@@ -685,6 +752,7 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
             model = model.transform(InferShapes())
         return (model, graph_modified)
 
+
 class MoveFlattenPastTopK(Transformation):
     """Move flatten node past a succeeding topk node, if the "axis" attribute in
     topk is set to -1 and the data layout before the flatten is NHWC with H=W=1"""
@@ -745,6 +813,7 @@ class MoveFlattenPastTopK(Transformation):
             model = model.transform(InferShapes())
         return (model, graph_modified)
 
+
 class MoveFlattenPastAffine(Transformation):
     """Moves a node that implements a (1, -1) reshape past a MatMul, Mul or Add node."""
 
@@ -831,9 +900,10 @@ class MoveFlattenPastAffine(Transformation):
 
         model = model.transform(InferShapes())
         model = model.transform(InferDataTypes())
-        model = model.transform(InferDataLayouts()) 
+        model = model.transform(InferDataLayouts())
         return (model, graph_modified)
-        
+
+
 class MoveTransposePastScalarMul(Transformation):
     """Moves a Transpose node
     past a scalar Mul node"""
@@ -895,4 +965,3 @@ class MoveTransposePastScalarMul(Transformation):
             model = model.transform(InferDataLayouts())
             model = model.transform(InferShapes())
         return (model, graph_modified)
-
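
A minimal usage sketch of the new transformation (not part of the patch), assuming the FINN API at this revision (ModelWrapper in finn.core.modelwrapper, InferShapes in finn.transformation.infer_shapes); the tensor names and shapes below are made up for illustration:

    # Builds inp -> Mul(scalar) -> Reshape -> out, then checks that the
    # transform reorders it to inp -> Reshape -> Mul(scalar) -> out.
    import numpy as np
    from onnx import TensorProto, helper

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.infer_shapes import InferShapes
    from finn.transformation.streamline.reorder import MoveScalarLinearPastInvariants

    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 2, 2, 2])
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 8])
    mul_node = helper.make_node("Mul", ["inp", "scale"], ["mul_out"])
    reshape_node = helper.make_node("Reshape", ["mul_out", "shape"], ["out"])
    graph = helper.make_graph([mul_node, reshape_node], "test", [inp], [out])
    model = ModelWrapper(helper.make_model(graph))
    model.set_initializer("scale", np.asarray([3.0], dtype=np.float32))  # scalar C
    model.set_initializer("shape", np.asarray([1, 8], dtype=np.int64))
    model = model.transform(InferShapes())

    model = model.transform(MoveScalarLinearPastInvariants())
    # the scalar Mul now sits after the Reshape: f(x*C) has become f(x)*C
    print([n.op_type for n in model.graph.node])  # expected: ['Reshape', 'Mul']

Note that when the graph is modified, the transformation re-runs InferShapes and InferDataTypes itself, and it conservatively marks the new intermediate tensor as FLOAT32 while carrying the scalar op's output datatype over to the final output.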