Unverified commit 6b04574d authored by Yaman Umuroglu, committed by GitHub

Merge pull request #113 from quetric/feature/linear_reorder_for_residual

Feature/linear reorder for residual
parents 055a0d1f bede1ab1
......@@ -333,6 +333,22 @@ class ModelWrapper:
else:
return None
def is_fork_node(self, node):
"""Checks if the given node is a fork, that is, the node has multiple
direct successors"""
direct_successors = self.find_direct_successors(node)
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
return is_fork
def is_join_node(self, node):
"""Checks if the given node is a join, that is, the node has multiple
direct predecessors"""
direct_predecessors = self.find_direct_predecessors(node)
is_join = (
False if direct_predecessors is None else (len(direct_predecessors) > 1)
)
return is_join
def get_all_tensor_names(self):
"""Returns a list of all (input, output and value_info) tensor names
in the graph."""
......
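For context, a minimal usage sketch of the two new ModelWrapper helpers (illustrative only, not part of the diff; the linear_nodes helper is hypothetical). It shows how a transformation can restrict itself to nodes on purely linear graph segments, which is exactly what the streamlining changes below do:

def linear_nodes(model):
    # yield only nodes that sit on a purely linear segment of the graph,
    # i.e. nodes that neither fork (multiple consumers) nor join
    # (multiple producers)
    for node in model.graph.node:
        if not model.is_fork_node(node) and not model.is_join_node(node):
            yield node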
......@@ -36,8 +36,9 @@ from finn.util.basic import get_by_name
class MoveAddPastMul(Transformation):
"""Move add operations past multiply operations. The aim is to have them
next to each other such that they can be collapsed into a single add."""
"""Move add operations past multiply operations on linear segments of the graph.
The aim is to have them next to each other such that they can be collapsed into
a single add."""
def apply(self, model):
graph = model.graph
......@@ -45,9 +46,17 @@ class MoveAddPastMul(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Add":
if (
n.op_type == "Add"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "Mul":
if (
consumer is not None
and consumer.op_type == "Mul"
and not model.is_join_node(consumer)
):
# have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
# want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
# assume input 0 is from the previous layer, input 1 is the
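A quick numeric sanity check of the rewrite sketched in the comment above (illustrative only, not part of the diff): moving the Add past the Mul is valid because (x + B) * A == x * A + B * A, so the add parameter simply becomes B * A:

import numpy as np

x = np.random.rand(4).astype(np.float32)
A = np.float32(0.5)  # mul parameter
B = np.float32(2.0)  # add parameter
# original order: add then mul; rewritten order: mul then add with B * A
assert np.allclose((x + B) * A, x * A + (B * A))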
......@@ -63,12 +72,16 @@ class MoveAddPastMul(Transformation):
end_name = consumer.output[0]
# compute new param value for add
BA = B * A
# make and insert new nodes
new_mul = oh.make_node(
"Mul", [start_name, mul_weight_name], [middle_name]
"Mul",
[start_name, mul_weight_name],
[middle_name],
name=consumer.name,
)
new_add = oh.make_node(
"Add", [middle_name, add_weight_name], [end_name]
"Add", [middle_name, add_weight_name], [end_name], name=n.name
)
graph.node.insert(node_ind, new_mul)
graph.node.insert(node_ind + 1, new_add)
......@@ -78,6 +91,7 @@ class MoveAddPastMul(Transformation):
graph.node.remove(n)
graph.node.remove(consumer)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
......@@ -92,9 +106,17 @@ class MoveScalarMulPastMatMul(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Mul":
if (
n.op_type == "Mul"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MatMul":
if (
consumer is not None
and consumer.op_type == "MatMul"
and not model.is_join_node(consumer)
):
mul_weight_name = n.input[1]
matmul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
......@@ -109,10 +131,16 @@ class MoveScalarMulPastMatMul(Transformation):
# if the mul is scalar, we can simply swap the order of ops
# make and insert new nodes
new_matmul = oh.make_node(
"MatMul", [start_name, matmul_weight_name], [middle_name]
"MatMul",
[start_name, matmul_weight_name],
[middle_name],
name=consumer.name,
)
new_mul = oh.make_node(
"Mul", [middle_name, mul_weight_name], [end_name]
"Mul",
[middle_name, mul_weight_name],
[end_name],
name=n.name,
)
graph.node.insert(node_ind, new_matmul)
graph.node.insert(node_ind + 1, new_mul)
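The scalar-Mul-past-MatMul case needs no parameter update at all, since a scalar multiply commutes with a MatMul. A small illustrative check (not part of the diff):

import numpy as np

x = np.random.rand(1, 2).astype(np.float32)
W = np.random.rand(2, 2).astype(np.float32)  # MatMul weight
a = np.float32(0.5)                          # scalar Mul parameter
assert np.allclose((x * a) @ W, (x @ W) * a)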
......@@ -135,9 +163,17 @@ class MoveScalarAddPastMatMul(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Add":
if (
n.op_type == "Add"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MatMul":
if (
consumer is not None
and consumer.op_type == "MatMul"
and not model.is_join_node(consumer)
):
add_weight_name = n.input[1]
matmul_weight_name = consumer.input[1]
A = model.get_initializer(add_weight_name)
......@@ -155,10 +191,16 @@ class MoveScalarAddPastMatMul(Transformation):
# update the add weight
model.set_initializer(add_weight_name, Anew)
new_matmul = oh.make_node(
"MatMul", [start_name, matmul_weight_name], [middle_name]
"MatMul",
[start_name, matmul_weight_name],
[middle_name],
name=consumer.name,
)
new_add = oh.make_node(
"Add", [middle_name, add_weight_name], [end_name]
"Add",
[middle_name, add_weight_name],
[end_name],
name=n.name,
)
graph.node.insert(node_ind, new_matmul)
graph.node.insert(node_ind + 1, new_add)
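For the scalar-Add-past-MatMul case the add parameter does have to change: (x + a) @ W == x @ W + (a * ones) @ W, which is what the updated add weight Anew above is meant to account for. A small illustrative check (not part of the diff):

import numpy as np

x = np.random.rand(1, 2).astype(np.float32)
W = np.random.rand(2, 2).astype(np.float32)  # MatMul weight
a = np.float32(0.5)                          # scalar Add parameter
Anew = (a * np.ones_like(x)) @ W             # new add parameter after the swap
assert np.allclose((x + a) @ W, x @ W + Anew)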
......@@ -181,9 +223,17 @@ class MoveScalarAddPastConv(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Add":
if (
n.op_type == "Add"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "Conv":
if (
consumer is not None
and consumer.op_type == "Conv"
and not model.is_join_node(consumer)
):
conv_node = consumer
add_node = n
add_weight_name = n.input[1]
......@@ -238,9 +288,17 @@ class MoveScalarMulPastConv(Transformation):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Mul":
if (
n.op_type == "Mul"
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "Conv":
if (
consumer is not None
and consumer.op_type == "Conv"
and not model.is_join_node(consumer)
):
mul_weight_name = n.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
......
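Similar identities justify the two Conv-related moves (MoveScalarAddPastConv and MoveScalarMulPastConv). A rough single-channel, no-padding sketch (illustrative only, not part of the diff): a scalar Mul commutes with the linear Conv, and a scalar Add moved past an unpadded Conv becomes an add of a * sum(W):

import numpy as np

def conv2d_valid(x, w):
    # naive single-channel "valid" cross-correlation; x is HxW, w is kh x kw
    kh, kw = w.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1), dtype=x.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i : i + kh, j : j + kw] * w)
    return out

x = np.random.rand(7, 7).astype(np.float32)
w = np.random.rand(3, 3).astype(np.float32)
a = np.float32(0.5)
# scalar Mul commutes with Conv
assert np.allclose(conv2d_valid(a * x, w), a * conv2d_valid(x, w), atol=1e-5)
# scalar Add past an unpadded Conv becomes an add of a * sum(w)
assert np.allclose(conv2d_valid(x + a, w), conv2d_valid(x, w) + a * w.sum(), atol=1e-5)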
......@@ -127,3 +127,45 @@ def test_modelwrapper_graph_order():
assert model.get_node_index(Round_node) == 1
assert model.get_node_index(Ceil_node) == 2
assert model.get_node_index(Add_node) == 3
def test_modelwrapper_detect_forks_n_joins():
# create small network with properties to be tested
Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"])
Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"])
Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
Add_node = onnx.helper.make_node(
"Add", inputs=["round1", "ceil1"], outputs=["out1"]
)
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
graph = onnx.helper.make_graph(
nodes=[Neg_node, Round_node, Ceil_node, Add_node],
name="simple_graph",
inputs=[in1],
outputs=[out1],
value_info=[
onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
onnx.helper.make_tensor_value_info(
"round1", onnx.TensorProto.FLOAT, [4, 4]
),
onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
],
)
onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
# test
assert model.is_fork_node(Neg_node)
assert not model.is_fork_node(Round_node)
assert not model.is_fork_node(Ceil_node)
assert not model.is_fork_node(Add_node)
assert not model.is_join_node(Neg_node)
assert not model.is_join_node(Round_node)
assert not model.is_join_node(Ceil_node)
assert model.is_join_node(Add_node)
......@@ -60,6 +60,9 @@ def test_move_add_past_mul_single():
new_model = model.transform(MoveAddPastMul())
inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Mul"
assert new_model.graph.node[1].op_type == "Add"
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
def test_move_add_past_mul_multi():
......@@ -92,3 +95,50 @@ def test_move_add_past_mul_multi():
new_model = model.transform(MoveAddPastMul())
inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Mul"
assert new_model.graph.node[1].op_type == "Mul"
assert new_model.graph.node[2].op_type == "Add"
assert new_model.graph.node[3].op_type == "Add"
for i in range(len(new_model.graph.node) - 1):
assert new_model.graph.node[i].output[0] == new_model.graph.node[i + 1].input[0]
def test_move_add_past_mul_only_if_linear():
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2])
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2])
value_info = [oh.make_tensor_value_info("add1_param", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("mul3_param", TensorProto.FLOAT, [1])]
modelproto = oh.make_model(
oh.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=[
oh.make_node("Add", ["top_in", "add1_param"], ["t1"]),
oh.make_node("Mul", ["t1", "mul1_param"], ["fork"]),
oh.make_node("Mul", ["fork", "mul2_param"], ["t3"]),
oh.make_node("Add", ["t3", "fork"], ["t4"]),
oh.make_node("Mul", ["t4", "mul3_param"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
np.random.seed(0)
model.set_initializer("add1_param", np.random.rand(2).astype(np.float32))
model.set_initializer("mul1_param", np.random.rand(2).astype(np.float32))
model.set_initializer("mul2_param", np.random.rand(2).astype(np.float32))
model.set_initializer("mul3_param", np.random.rand(2).astype(np.float32))
new_model = model.transform(MoveAddPastMul())
inp_dict = {"top_in": np.random.rand(2).astype(np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Mul"
assert new_model.graph.node[1].op_type == "Add"
assert new_model.graph.node[2].op_type == "Mul"
assert new_model.graph.node[3].op_type == "Add"
assert new_model.graph.node[4].op_type == "Mul"
import numpy as np
import onnx.helper as oh
import pytest
from onnx import TensorProto
import finn.core.onnx_exec as ox
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import (
MoveScalarAddPastConv,
MoveScalarMulPastConv,
)
@pytest.mark.parametrize(
"test_args", [("Add", MoveScalarAddPastConv()), ("Mul", MoveScalarMulPastConv())],
)
def test_move_scalar_past_conv_only_if_linear(test_args):
scalar_op = test_args[0]
transf_fxn = test_args[1]
in_feature_dim = 7
in_chn = 1
padding = False
stages = 3
kernel_size = 3
out_feature_dim = (
in_feature_dim if padding else in_feature_dim - (kernel_size // 2 * 2) * stages
)
input_shape = [1, in_chn, in_feature_dim, in_feature_dim]
output_shape = [1, in_chn, out_feature_dim, out_feature_dim]
conv_param_shape = [in_chn, in_chn, kernel_size, kernel_size]
conv_config = {}
conv_config["dilations"] = [1, 1]
conv_config["group"] = 1
conv_config["kernel_shape"] = [kernel_size, kernel_size]
conv_config["pads"] = [0, 0, 0, 0]
conv_config["strides"] = [1, 1]
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
value_info = [oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1])]
value_info += [oh.make_tensor_value_info("p2", TensorProto.FLOAT, conv_param_shape)]
value_info += [oh.make_tensor_value_info("p3", TensorProto.FLOAT, conv_param_shape)]
value_info += [oh.make_tensor_value_info("p4", TensorProto.FLOAT, conv_param_shape)]
value_info += [oh.make_tensor_value_info("p5", TensorProto.FLOAT, conv_param_shape)]
modelproto = oh.make_model(
oh.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=[
oh.make_node("Conv", ["top_in", "p2"], ["t1"], **conv_config),
oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
oh.make_node("Conv", ["t2", "p3"], ["t3"], **conv_config),
oh.make_node("Conv", ["t2", "p4"], ["t4"], **conv_config),
oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
oh.make_node("Conv", ["t5", "p5"], ["top_out"], **conv_config),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
np.random.seed(0)
model.set_initializer("p1", *np.random.rand(1).astype(np.float32))
model.set_initializer("p2", np.random.rand(*conv_param_shape).astype(np.float32))
model.set_initializer("p3", np.random.rand(*conv_param_shape).astype(np.float32))
model.set_initializer("p4", np.random.rand(*conv_param_shape).astype(np.float32))
model.set_initializer("p5", np.random.rand(*conv_param_shape).astype(np.float32))
new_model = model.transform(transf_fxn)
inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "Conv"
assert new_model.graph.node[1].op_type == scalar_op
assert new_model.graph.node[2].op_type == "Conv"
assert new_model.graph.node[3].op_type == "Conv"
assert new_model.graph.node[4].op_type == scalar_op
assert new_model.graph.node[5].op_type == "Conv"
......@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
import pytest
import onnx.helper as oh
from onnx import TensorProto
......@@ -99,3 +100,56 @@ def test_move_scalar_add_past_matmul():
assert new_model.graph.node[0].op_type == "MatMul"
assert new_model.graph.node[1].op_type == "Add"
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
@pytest.mark.parametrize(
"test_args",
[("Add", MoveScalarAddPastMatMul()), ("Mul", MoveScalarMulPastMatMul())],
)
def test_move_scalar_past_matmul_only_if_linear(test_args):
scalar_op = test_args[0]
transf_fxn = test_args[1]
input_shape = [1, 2]
matmul_shape = [2, 2]
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
p1 = oh.make_tensor_value_info("p1", TensorProto.FLOAT, [1, 1])
p2 = oh.make_tensor_value_info("p2", TensorProto.FLOAT, matmul_shape)
p3 = oh.make_tensor_value_info("p3", TensorProto.FLOAT, matmul_shape)
p4 = oh.make_tensor_value_info("p4", TensorProto.FLOAT, matmul_shape)
modelproto = oh.make_model(
oh.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=[p1, p2, p3, p4],
nodes=[
oh.make_node(scalar_op, ["top_in", "p1"], ["t1"]),
oh.make_node("MatMul", ["t1", "p2"], ["fork"]),
oh.make_node("MatMul", ["fork", "p3"], ["t3"]),
oh.make_node(scalar_op, ["t3", "fork"], ["t4"]),
oh.make_node("MatMul", ["t4", "p4"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
np.random.seed(0)
model.set_initializer("p1", np.random.rand(1, 1).astype(np.float32))
model.set_initializer("p2", np.random.rand(*matmul_shape).astype(np.float32))
model.set_initializer("p3", np.random.rand(*matmul_shape).astype(np.float32))
model.set_initializer("p4", np.random.rand(*matmul_shape).astype(np.float32))
# Transform
new_model = model.transform(transf_fxn)
# Test
inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "MatMul"
assert new_model.graph.node[1].op_type == scalar_op
assert new_model.graph.node[2].op_type == "MatMul"
assert new_model.graph.node[3].op_type == scalar_op
assert new_model.graph.node[4].op_type == "MatMul"