Skip to content
Snippets Groups Projects
Commit d54cb43c authored by auphelia's avatar auphelia
Browse files

[Streamline] Add AbsorbScalarMulIntoTopK transformation

parent cf54f617
No related branches found
No related tags found
No related merge requests found
......@@ -28,11 +28,13 @@
import numpy as np
from onnx import helper as oh
import warnings
from finn.core.datatype import DataType
from finn.transformation import Transformation
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
......@@ -290,3 +292,38 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
if graph_modified:
model = model.transform(InferDataTypes())
return (model, graph_modified)
class AbsorbScalarMulIntoTopK(Transformation):
"""Absorb a mul node into a suceeding topk node if the mul is scalar."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "TopK":
prod = model.find_producer(n.input[0])
if prod is not None and prod.op_type == "Mul":
prod_input = prod.input[0]
param_name = prod.input[1]
A = model.get_initializer(param_name)
if A is None:
warnings.warn("Param is not constant, skipping")
continue
if all(x == 1 for x in A.shape) and A > 0:
# if the mul is scalar and positive, we can just delete the
# mul node and rewire the top k node. Because the top k node
# works with probabilities and their relation to each other
# the relation doesn't change if every value is multiplied
# with a scalar
graph.node.remove(prod)
n.input[0] = prod_input
# to avoid error the dataype is set to float32
model.set_tensor_datatype(n.input[0], DataType.FLOAT32)
graph_modified = True
if graph_modified:
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment