Unverified commit 9ddad020, authored by Yaman Umuroglu, committed by GitHub

Merge pull request #120 from quetric/feature/fix_streamline_past_eltwise_add

Feature/fix streamline past eltwise add
Parents: 4062647c, 18067458
@@ -34,8 +34,6 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
-def is_scalar(x):
-    return np.prod(x.shape) == 1
 class MoveAddPastMul(Transformation):
     """Move add operations past multiply operations. The aim is to have them
@@ -273,12 +271,12 @@ class MoveScalarMulPastConv(Transformation):
         return (model, graph_modified)
-class MoveScalarLinearPastEltwiseAdd(Transformation):
-    """Move scalar linear operations (mul, add) past elementwise add operations where possible. Specifically,
-    matches and transforms the following patterns:
+class MoveLinearPastEltwiseAdd(Transformation):
+    """Move linear operations (mul, add) past elementwise add operations where possible.
+    Specifically, matches and transforms the following patterns:
     (x*C) + (y*C) -> (x + y) * C
     (x+A) + (y+B) -> (x + y) + (A + B)
-    where x and y are dynamic inputs, A, B, C are constants.
+    where x and y are dynamic inputs, and A, B, C are constant tensors (in general).
     """

     def move_node(self, graph, n, prod0, prod1, node_ind):
@@ -305,7 +303,8 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
+        nodes = [n for n in graph.node]
+        for n in nodes:
             node_ind += 1
             if n.op_type == "Add":
                 # check for tensors on both inputs (eltwise add)
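The `nodes = [n for n in graph.node]` change above iterates over a snapshot of the node list rather than the live `graph.node` container; presumably this is because `move_node` rewires and removes nodes while the loop is running, and mutating a container mid-iteration can skip entries. A toy illustration of that hazard in plain Python (not FINN code):

```python
# Toy sketch (not FINN code): removing elements from a list while
# iterating over it directly makes the loop skip entries.
items = [1, 2, 3, 4]
for x in list(items):  # iterate over a snapshot, like `nodes = [n for n in graph.node]`
    if x % 2 == 0:
        items.remove(x)  # safe: the mutation does not disturb the snapshot
assert items == [1, 3]
```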
@@ -321,17 +320,16 @@ class MoveScalarLinearPastEltwiseAdd(Transformation):
                 # check for mul with same initializer on both inputs
                 prod0 = model.find_producer(in0)
                 prod1 = model.find_producer(in1)
-                if prod0 is None or prod1 is None:
+                # also check the case where both branches are empty and come
+                # from the same node (prod0 == prod1);
+                # another transform should handle that
+                if prod0 is None or prod1 is None or (prod0 == prod1):
                     continue
                 init0 = model.get_initializer(prod0.input[1])
                 init1 = model.get_initializer(prod1.input[1])
                 # if either initializer is None, skip
                 if init0 is None or init1 is None:
                     continue
-                # if either initializer is non-scalar, skip
-                # TODO relax this to 1D tensors?
-                if (not is_scalar(init0)) or (not is_scalar(init1)):
-                    continue
                 if prod0.op_type == "Mul" and prod1.op_type == "Mul":
                     if np.array_equal(init0, init1):
                         self.move_node(graph, n, prod0, prod1, node_ind)
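As a quick sanity check of the patterns in the new docstring, the identities can be exercised with plain numpy, independent of FINN. Note that the Mul pattern is only valid when both branches share the same constant C, which is what the `np.array_equal(init0, init1)` guard enforces; the shape below is an arbitrary choice for illustration:

```python
import numpy as np

rng = np.random.default_rng(42)
x, y, A, B, C = (rng.random((1, 8), dtype=np.float32) for _ in range(5))

# (x*C) + (y*C) -> (x + y) * C, valid only because both Muls share C
assert np.allclose((x * C) + (y * C), (x + y) * C)
# (x+A) + (y+B) -> (x + y) + (A + B), where A and B may differ
assert np.allclose((x + A) + (y + B), (x + y) + (A + B))
```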
@@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
-from finn.transformation.streamline.reorder import MoveScalarLinearPastEltwiseAdd
+from finn.transformation.streamline.reorder import MoveLinearPastEltwiseAdd
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
@@ -95,7 +95,7 @@ def make_model(shape):
 @pytest.mark.parametrize("ch", [64])
 # ifmdim
 @pytest.mark.parametrize("ifmdim", [-1, 7])
-def test_scalar_past_eltwise(ch, ifmdim):
+def test_linear_past_eltwise_add(ch, ifmdim):
     # generate test vectors of correct shape
     if ifmdim == -1:
         input_tensor_shape = (1, ch)
@@ -124,7 +124,7 @@ def test_scalar_past_eltwise(ch, ifmdim):
     assert len(model.get_nodes_by_op_type("Add")) == 3
     assert len(model.get_nodes_by_op_type("Mul")) == 2
-    model = model.transform(MoveScalarLinearPastEltwiseAdd())
+    model = model.transform(MoveLinearPastEltwiseAdd())
     # verify again, to check we didn't break anything
     output_dict = oxe.execute_onnx(model, input_dict, True)
@@ -134,3 +134,68 @@ def test_scalar_past_eltwise(ch, ifmdim):
     assert len(model.get_nodes_by_op_type("Mul")) == 1
     os.remove(export_onnx_path)
+
+
+@pytest.mark.parametrize("ch", [64, 1])
+# ifmdim
+@pytest.mark.parametrize("ifmdim", [-1, 7])
+def test_linear_past_eltwise_add_multiple_forks(ch, ifmdim):
+    # generate test vectors of correct shape
+    if ifmdim == -1:
+        input_shape = (1, ch)
+    else:
+        input_shape = (1, ch, ifmdim, ifmdim)
+
+    top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
+    top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
+
+    num_of_params = 6
+    value_info = []
+    for i in range(num_of_params):
+        value_info += [
+            helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
+        ]
+
+    modelproto = helper.make_model(
+        helper.make_graph(
+            name="test",
+            inputs=[top_in],
+            outputs=[top_out],
+            value_info=value_info,
+            nodes=[
+                helper.make_node("Add", ["top_in", "p0"], ["fork1"]),
+                helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
+                helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
+                helper.make_node("Add", ["t2", "t3"], ["t4"]),
+                helper.make_node("Mul", ["t4", "p3"], ["fork2"]),
+                helper.make_node("Add", ["fork2", "p4"], ["t5"]),
+                helper.make_node("Add", ["fork2", "p5"], ["t6"]),
+                helper.make_node("Add", ["t5", "t6"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+
+    np.random.seed(0)
+    for i in range(num_of_params):
+        model.set_initializer(
+            "p" + str(i), np.random.rand(*input_shape).astype(np.float32)
+        )
+    # need equal mults:
+    model.set_initializer("p2", model.get_initializer("p1"))
+
+    # Transform
+    new_model = model.transform(MoveLinearPastEltwiseAdd())
+    inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
+
+    # Test
+    assert oxe.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "Add"
+    assert new_model.graph.node[1].op_type == "Add"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[3].op_type == "Mul"
+    assert new_model.graph.node[4].op_type == "Add"
+    assert new_model.graph.node[5].op_type == "Add"
+    assert len(new_model.graph.node) == 6
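To see why the transformed graph is expected to start with two Adds, the fork graph's algebra can be re-derived with plain numpy (a sketch only; names mirror the test above, and `p1 == p2` is what lets the Muls merge):

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (1, 64)
x, p0, p1, p3, p4, p5 = (rng.random(shape, dtype=np.float32) for _ in range(6))

fork1 = x + p0
# original order off fork1: Mul, Mul, Add, then Mul
fork2 = (fork1 * p1 + fork1 * p1) * p3
# after the rewrite: the Add moves ahead of the shared-constant Mul
assert np.allclose(fork2, ((fork1 + fork1) * p1) * p3)
# final fork: the two Add constants merge past the last eltwise Add
assert np.allclose((fork2 + p4) + (fork2 + p5), (fork2 + fork2) + (p4 + p5))
```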