diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 82cc6b42d01103fbed2a01001fa55daeb19d5d22..65e67721c8b6b44af060ba59eddfcd72dadc5fc7 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -28,14 +28,81 @@ import numpy as np from onnx import helper as oh +import warnings from finn.core.datatype import DataType +import finn.core.data_layout as DataLayout from finn.transformation import Transformation from finn.util.basic import get_by_name from finn.custom_op.registry import getCustomOp +from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes +class AbsorbSignBiasIntoMultiThreshold(Transformation): + """Absorb scalar bias originating from signed int export back into + MultiThreshold and re-evaluate the output datatype.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + # search for (MultiThreshold, Add) pair + node_ind += 1 + if ( + n.op_type == "MultiThreshold" + and not model.is_fork_node(n) + and not model.is_join_node(n) + ): + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "Add": + mt_node = n + add_node = consumer + threshold_name = mt_node.input[1] + add_weight_name = add_node.input[1] + T = model.get_initializer(threshold_name) + A = model.get_initializer(add_weight_name) + if (A is None) or (T is None): + warnings.warn("Threshold or add bias not constant, skipping") + continue + end_name = add_node.output[0] + # we can only absorb scalar adds + is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) + if not is_scalar: + continue + bias = A.flatten()[0] + # set MultiThreshold bias property + mt_inst = getCustomOp(mt_node) + bias += mt_inst.get_nodeattr("out_bias") + mt_inst.set_nodeattr("out_bias", bias) + graph_modified = True + # compute new DataType for MultiThreshold output + steps = T.shape[-1] + new_min = bias + new_max = steps + bias + odt = DataType.get_smallest_possible(steps).name.replace( + "UINT", "INT" + ) + odt = DataType[odt] + assert odt.allowed(new_max) and odt.allowed( + new_min + ), """Could + not compute new MultiThreshold DataType (min = %d max = %d)""" % ( + new_min, + new_max, + ) + mt_inst.set_nodeattr("out_dtype", odt.name) + # remove Add node, rewire MultiThreshold + graph.node.remove(add_node) + mt_node.output[0] = end_name + # set datatype + model.set_tensor_datatype(end_name, odt) + if graph_modified: + model = model.transform(InferDataTypes()) + return (model, graph_modified) + + class AbsorbAddIntoMultiThreshold(Transformation): """Absorb preceding Add ops into MultiThreshold by updating the threshold values. Only scalar/1D add vectors can be absorbed.""" @@ -292,6 +359,104 @@ class AbsorbTransposeIntoMultiThreshold(Transformation): return (model, graph_modified) +class AbsorbTransposeIntoFlatten(Transformation): + """Absorb transpose node into succeeding flatten node, if H=W=1 and the first + dimension stays the same. Can also be applied if flatten is implemented implicitly + by a reshape node with shape [1, -1] and the first input dimension is 1""" + + def apply(self, model): + graph = model.graph + graph_modified = False + node_ind = 0 + for n in graph.node: + node_ind += 1 + if ( + n.op_type == "Reshape" + and (model.get_initializer(n.input[1]) == [1, -1]).all() + ) or n.op_type == "Flatten": + prod = model.find_producer(n.input[0]) + if ( + prod is not None + and prod.op_type == "Transpose" + # we ensure that the first dimension is not changed from the + # transpose operation + and get_by_name(prod.attribute, "perm").ints[0] == 0 + ): + data_layout = model.get_tensor_layout(prod.input[0]) + # check for the data layout to interpret input shape correctly + if data_layout is None: + warnings.warn( + """Data layout for input tensor of Transpose node is not set. + To use AbsorbTransposeIntoFlatten transformation + please set tensor data layout.""" + ) + continue + elif data_layout == DataLayout.NCHW: + (b, c, h, w) = model.get_tensor_shape(prod.input[0]) + # if h=w=1 the transposition can be absorbed, otherwise + # the absorption would lead to an error in the behavior + if h != 1 or w != 1: + continue + # the flatten node from onnx keeps by default the first + # dim and flattens the rest, that is why this transformation + # can only work with b != 1 if the model contains already a + # flatten node and not a reshape node with shape = [1, -1]. + # If the first dim of the input tensor is not 1, flatten and + # reshape (with shape = [1, -1]) would lead to different results + if n.op_type == "Reshape" and b != 1: + continue + elif data_layout == DataLayout.NHWC: + (b, h, w, c) = model.get_tensor_shape(prod.input[0]) + if h != 1 or w != 1: + continue + if n.op_type == "Reshape" and b != 1: + continue + # create single flatten node and remove obsolete nodes + node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]]) + graph.node.remove(n) + graph.node.remove(prod) + graph.node.insert(node_ind, node) + graph_modified = True + if graph_modified: + model = model.transform(InferDataTypes()) + return (model, graph_modified) + + +class AbsorbScalarMulIntoTopK(Transformation): + """Absorb a mul node into a suceeding topk node if the mul is scalar.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "TopK": + prod = model.find_producer(n.input[0]) + if prod is not None and prod.op_type == "Mul": + prod_input = prod.input[0] + param_name = prod.input[1] + A = model.get_initializer(param_name) + if A is None: + warnings.warn("Param is not constant, skipping") + continue + if all(x == 1 for x in A.shape) and A > 0: + # if the mul is scalar and positive, we can just delete the + # mul node and rewire the top k node. Because the top k node + # works with probabilities and their relation to each other + # the relation doesn't change if every value is multiplied + # with a scalar + graph.node.remove(prod) + n.input[0] = prod_input + # to avoid error the dataype is set to float32 + model.set_tensor_datatype(n.input[0], DataType.FLOAT32) + graph_modified = True + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified) + + class AbsorbConsecutiveTransposes(Transformation): """Remove (Transpose -> Transpose) patterns when the input and output of the pattern have the same layout."""