Skip to content
Snippets Groups Projects
Unverified Commit 603b8bad authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #115 from quetric/feature/linear_collapse_repeated

Feature/linear collapse repeated
parents 1cada3f9 e09d6759
No related branches found
No related tags found
No related merge requests found
...@@ -48,9 +48,17 @@ class CollapseRepeatedOp(Transformation): ...@@ -48,9 +48,17 @@ class CollapseRepeatedOp(Transformation):
graph_modified = False graph_modified = False
for n in graph.node: for n in graph.node:
node_ind += 1 node_ind += 1
if n.op_type == self.op_name: if (
n.op_type == self.op_name
and not model.is_fork_node(n)
and not model.is_join_node(n)
):
consumer = model.find_consumer(n.output[0]) consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == self.op_name: if (
consumer is not None
and consumer.op_type == self.op_name
and not model.is_join_node(consumer)
):
op0_param_name = n.input[1] op0_param_name = n.input[1]
op1_param_name = consumer.input[1] op1_param_name = consumer.input[1]
op0_param = model.get_initializer(op0_param_name) op0_param = model.get_initializer(op0_param_name)
......
...@@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox ...@@ -34,6 +34,7 @@ import finn.core.onnx_exec as ox
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
import pytest
def test_collapse_repeated_op(): def test_collapse_repeated_op():
...@@ -67,3 +68,60 @@ def test_collapse_repeated_op(): ...@@ -67,3 +68,60 @@ def test_collapse_repeated_op():
new_model = new_model.transform(CollapseRepeatedMul()) new_model = new_model.transform(CollapseRepeatedMul())
inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
assert ox.compare_execution(model, new_model, inp_dict) assert ox.compare_execution(model, new_model, inp_dict)
assert len(new_model.graph.node) == 2
assert new_model.graph.node[0].op_type == "Add"
assert new_model.graph.node[1].op_type == "Mul"
@pytest.mark.parametrize(
    "test_args", [("Add", CollapseRepeatedAdd()), ("Mul", CollapseRepeatedMul())],
)
def test_collapse_repeated_only_if_linear(test_args):
    """Check that repeated scalar ops are collapsed only on linear paths.

    The graph below contains a fork (t2 feeds two nodes) and a join
    (t3 and t4 merge into t5); those nodes must be left untouched, so
    only the final linear pair may be merged: 6 nodes -> 5 nodes.
    """
    scalar_op, transf_fxn = test_args

    shape = [4, 4]
    top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, shape)
    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, shape)

    # scalar parameter tensors p1..p5, one per parameterized node
    param_names = ["p1", "p2", "p3", "p4", "p5"]
    value_info = [
        oh.make_tensor_value_info(name, TensorProto.FLOAT, [1])
        for name in param_names
    ]

    graph = oh.make_graph(
        name="test",
        inputs=[top_in],
        outputs=[top_out],
        value_info=value_info,
        nodes=[
            oh.make_node(scalar_op, ["top_in", "p2"], ["t1"]),
            oh.make_node(scalar_op, ["t1", "p1"], ["t2"]),
            # t2 forks into two parallel branches ...
            oh.make_node(scalar_op, ["t2", "p3"], ["t3"]),
            oh.make_node(scalar_op, ["t2", "p4"], ["t4"]),
            # ... which join again here
            oh.make_node(scalar_op, ["t3", "t4"], ["t5"]),
            oh.make_node(scalar_op, ["t5", "p5"], ["top_out"]),
        ],
    )
    model = ModelWrapper(oh.make_model(graph))
    model = model.transform(InferShapes())

    # deterministic parameter values; order of the calls matters for
    # reproducibility since each one consumes from the seeded RNG stream
    np.random.seed(0)
    for name in param_names:
        model.set_initializer(name, *np.random.rand(1).astype(np.float32))

    # Transform
    new_model = model.transform(transf_fxn)

    # Test: numerically equivalent, and exactly one pair was collapsed
    inp_dict = {"top_in": np.random.rand(*shape).astype(np.float32)}
    assert ox.compare_execution(model, new_model, inp_dict)
    assert len(new_model.graph.node) == 5
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment