Commit e9aa670f authored by auphelia

[Streamline] Add MoveFlattenPastAffine operation

parent cf54f617
@@ -32,6 +32,7 @@ from onnx import helper as oh
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
@@ -597,3 +598,74 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
        model = model.transform(InferShapes())
        return (model, graph_modified)


class MoveFlattenPastAffine(Transformation):
    """Moves a node that implements a (1, -1) reshape past a MatMul, Mul or Add node."""

    def apply(self, model):
        graph = model.graph
        graph_modified = False
        node_ind = 0
        for n in graph.node:
            node_ind += 1
            if (
                n.op_type == "Flatten"
                and not model.is_fork_node(n)
                and not model.is_join_node(n)
            ):
                consumer = model.find_consumer(n.output[0])
                if (
                    consumer is not None
                    and (
                        consumer.op_type == "MatMul"
                        or consumer.op_type == "Mul"
                        or consumer.op_type == "Add"
                    )
                    and not model.is_join_node(consumer)
                ):
                    # move Flatten past the affine operation and rewire tensors
                    start_name = n.input[0]
                    middle_name = n.output[0]
                    end_name = consumer.output[0]
                    op_param_name = consumer.input[1]
                    A = model.get_initializer(op_param_name)
                    if A is None:
                        warnings.warn("Param is not constant, skipping")
                        continue
                    op_in_dt = model.get_tensor_datatype(consumer.input[0])
                    op_out_dt = model.get_tensor_datatype(consumer.output[0])
                    start_shape = model.get_tensor_shape(start_name)
                    # execute the op on a dummy input to derive the shape of the
                    # new intermediate tensor
                    dummy_in = np.random.uniform(low=0, high=1, size=start_shape)
                    if consumer.op_type == "MatMul":
                        dummy_out = np.matmul(dummy_in, A)
                    elif consumer.op_type == "Mul":
                        dummy_out = dummy_in * A
                    elif consumer.op_type == "Add":
                        dummy_out = dummy_in + A
                    new_op = oh.make_node(
                        consumer.op_type,
                        [start_name, op_param_name],
                        [middle_name],
                        name=consumer.name,
                    )
                    new_flatten = oh.make_node("Flatten", [middle_name], [end_name])
                    graph.node.insert(node_ind, new_op)
                    graph.node.insert(node_ind + 1, new_flatten)
                    model.set_tensor_shape(middle_name, dummy_out.shape)
                    # because a Flatten node does not change the datatype, only the
                    # datatypes of the affine op node are needed
                    model.set_tensor_datatype(start_name, op_in_dt)
                    model.set_tensor_datatype(middle_name, op_out_dt)
                    model.set_tensor_datatype(end_name, op_out_dt)
                    # remove old nodes
                    graph.node.remove(n)
                    graph.node.remove(consumer)
                    graph_modified = True
        model = model.transform(InferShapes())
        model = model.transform(InferDataTypes())
        return (model, graph_modified)
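
For reference, a minimal usage sketch of the new transformation (not part of this commit). It assumes a FINN ModelWrapper loaded from an ONNX file; the file names are placeholders, and the import path follows the file modified in this diff (finn/transformation/streamline/reorder.py). The numpy check at the end only illustrates why the reordering preserves values for an elementwise Mul with a scalar parameter.

import numpy as np
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.reorder import MoveFlattenPastAffine

# apply the transformation to a model (file names are placeholders)
model = ModelWrapper("model.onnx")
model = model.transform(MoveFlattenPastAffine())
model.save("model_streamlined.onnx")

# sanity check: Flatten -> Mul and Mul -> Flatten give identical values
x = np.random.rand(1, 1, 1, 8).astype(np.float32)
a = np.float32(0.5)
before = x.reshape(1, -1) * a    # original order: Flatten, then Mul
after = (x * a).reshape(1, -1)   # rewired order: Mul, then Flatten
assert np.allclose(before, after)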