Skip to content
Snippets Groups Projects
Unverified Commit ddd52d2d authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

[Transform] scalar chk in MoveScalarLinearPastEltwiseAdd

parent 623266c2
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,8 @@ from finn.transformation.infer_shapes import InferShapes
from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name
def is_scalar(x):
return np.prod(x.shape) == 1
class MoveAddPastMul(Transformation):
"""Move add operations past multiply operations. The aim is to have them
......@@ -326,6 +328,10 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
# if either initializer is None, skip
if init0 is None or init1 is None:
continue
# if either initializer is non-scalar, skip
# TODO relax this to 1D tensors?
if (not is_scalar(init0)) or (not is_scalar(init1)):
continue
if prod0.op_type == "Mul" and prod1.op_type == "Mul":
if np.array_equal(init0, init1):
self.move_node(graph, n, prod0, prod1, node_ind)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment