Skip to content
Snippets Groups Projects
Commit 4ede818c authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[Transformation] New transform to move linear ops past several ops

parent ca3f3446
No related branches found
No related tags found
No related merge requests found
......@@ -502,6 +502,68 @@ class MoveLinearPastEltwiseAdd(Transformation):
return (model, graph_modified)
class MoveScalarLinearPastInvariants(Transformation):
    """Move scalar linear operations (mul, add) past functions which are invariant
    to them. Specifically, matches and transforms the following patterns:
    f(x*C) -> f(x) * C
    f(x+C) -> f(x) + C
    where x is a dynamic input, C is a constant tensor.
    Known f which obey this property are: Reshape, Flatten, Transpose,
    GlobalAveragePool
    """

    def apply(self, model):
        """Apply the transformation to *model*.

        Returns a ``(model, graph_modified)`` tuple; ``graph_modified`` is True
        when at least one scalar Mul/Add was moved past an invariant op.
        """
        graph = model.graph
        node_ind = 0
        graph_modified = False
        # Snapshot the node list so the graph mutations below (remove/insert)
        # do not disturb iteration.
        nodes = [n for n in graph.node]
        for n in nodes:
            node_ind += 1
            # Only these op types are known to commute with a scalar Mul/Add
            # (see class docstring).
            if (
                n.op_type == "GlobalAveragePool"
                or n.op_type == "Reshape"
                or n.op_type == "Transpose"
                or n.op_type == "Flatten"
            ):
                in0 = n.input[0]
                if in0 is None:
                    continue
                # find and check producer on our input
                prod0 = model.find_producer(in0)
                if prod0 is None:
                    # graph (dynamic) input feeds n directly; nothing to move
                    continue
                if prod0.op_type == "Mul" or prod0.op_type == "Add":
                    # check if second input of producer is an initializer
                    init0 = model.get_initializer(prod0.input[1])
                    # if either initializer is None, skip
                    if init0 is None:
                        continue
                    # if initializer is not scalar, skip
                    if np.prod(init0.shape) != 1:
                        continue
                    # NOTE(review): if prod0 also feeds consumers other than n,
                    # the rewiring below detaches them from prod0's result —
                    # confirm single-consumer producers, or add a fork check.
                    # move prod0 from input to output,
                    # i.e. rewire so n reads prod0's original input directly,
                    # and prod0 reads n's output instead:
                    #   old:  x -> prod0 -> n -> out
                    #   new:  x -> n -> prod0 -> out
                    old_prod0_in = prod0.input[0]
                    old_prod0_out = prod0.output[0]
                    old_n_out = n.output[0]
                    # record shapes before rewiring; tensor names swap roles
                    in_shape = model.get_tensor_shape(n.input[0])
                    out_shape = model.get_tensor_shape(n.output[0])
                    n.input[0] = old_prod0_in
                    n.output[0] = old_prod0_out
                    prod0.input[0] = old_prod0_out
                    prod0.output[0] = old_n_out
                    # re-annotate shapes for the tensors in their new roles
                    model.set_tensor_shape(n.input[0], in_shape)
                    model.set_tensor_shape(n.output[0], out_shape)
                    model.set_tensor_shape(prod0.output[0], out_shape)
                    # reinsert prod0 right after n's position to keep the node
                    # list topologically sorted
                    graph.node.remove(prod0)
                    graph.node.insert(node_ind - 1, prod0)
                    graph_modified = True
                else:
                    continue
        # re-run shape inference so all intermediate tensors are consistent
        model = model.transform(InferShapes())
        return (model, graph_modified)
class MakeMaxPoolNHWC(Transformation):
"""Convert (MaxPool, NHWCTranpose) into (MaxPoolNHWC)."""
......@@ -685,6 +747,7 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveFlattenPastTopK(Transformation):
"""Move flatten node past a succeeding topk node, if the "axis" attribute in topk
is set to -1 and the data layout before the flatten is NHWC with H=W=1"""
......@@ -745,6 +808,7 @@ class MoveFlattenPastTopK(Transformation):
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveFlattenPastAffine(Transformation):
"""Moves a node that implements a (1, -1) reshape past a MatMul, Mul or Add node."""
......@@ -831,9 +895,10 @@ class MoveFlattenPastAffine(Transformation):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(InferDataLayouts())
return (model, graph_modified)
class MoveTransposePastScalarMul(Transformation):
"""Moves a Transpose node past a scalar Mul node"""
......@@ -895,4 +960,3 @@ class MoveTransposePastScalarMul(Transformation):
model = model.transform(InferDataLayouts())
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment