From d54cb43cb66f4d5b913eefe791325cfc9d3d1c7e Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Thu, 25 Jun 2020 15:29:10 +0100
Subject: [PATCH] [Streamline] Add AbsorbScalarMulIntoTopK transformation

---
 src/finn/transformation/streamline/absorb.py | 37 ++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index dbcf97361..4266488c7 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -28,11 +28,13 @@
 
 import numpy as np
 from onnx import helper as oh
+import warnings
 
 from finn.core.datatype import DataType
 from finn.transformation import Transformation
 from finn.util.basic import get_by_name
 from finn.custom_op.registry import getCustomOp
+from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 
 
@@ -290,3 +292,38 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
         if graph_modified:
             model = model.transform(InferDataTypes())
         return (model, graph_modified)
+
+
+class AbsorbScalarMulIntoTopK(Transformation):
+    """Absorb a mul node into a succeeding topk node if the mul is a positive scalar."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "TopK":
+                prod = model.find_producer(n.input[0])
+                if prod is not None and prod.op_type == "Mul":
+                    prod_input = prod.input[0]
+                    param_name = prod.input[1]
+                    A = model.get_initializer(param_name)
+                    if A is None:
+                        warnings.warn("Param is not constant, skipping")
+                        continue
+                    if all(x == 1 for x in A.shape) and A > 0:
+                        # if the mul is scalar and positive, we can remove the
+                        # mul node and rewire the TopK node directly: TopK only
+                        # depends on the relative ordering of its inputs, and
+                        # multiplying every value by the same positive scalar
+                        # does not change which indices are selected
+                        graph.node.remove(prod)
+                        n.input[0] = prod_input
+                        # set the datatype to float32 to avoid downstream errors
+                        model.set_tensor_datatype(n.input[0], DataType.FLOAT32)
+                        graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
-- 
GitLab
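
For context, a minimal usage sketch of the new transformation (not taken from the patch itself): the file name "model.onnx", the output name, and the presence of a positive scalar Mul feeding a TopK node in that model are assumptions for illustration only.

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK

    # load a model that (assumed) contains ... -> Mul(positive scalar) -> TopK
    model = ModelWrapper("model.onnx")  # hypothetical input file
    # the scalar Mul is removed and TopK is rewired to the Mul's input;
    # the selected top-k indices are unchanged by positive scaling
    model = model.transform(AbsorbScalarMulIntoTopK())
    model.save("model_absorbed.onnx")  # hypothetical output file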