Skip to content
Snippets Groups Projects
Unverified Commit 24022404 authored by auphelia's avatar auphelia Committed by GitHub
Browse files

Merge pull request #127 from quetric/feature/MoveOpPastFork_transform

Feature/move op past fork transform
parents 603b8bad e31bc457
No related branches found
No related tags found
No related merge requests found
......@@ -444,3 +444,90 @@ class MakeMaxPoolNHWC(Transformation):
graph.node.insert(node_ind - 1, consumer)
graph_modified = True
return (model, graph_modified)
class MoveOpPastFork(Transformation):
"""Move node operations past graph forks. Used when a node before a fork
can be merged with nodes in the branches
"""
def __init__(self, op_name_list):
super().__init__()
self.ops_to_move = op_name_list
def apply(self, model):
graph = model.graph
graph_modified = False
nodes = [n for n in graph.node]
node_ind = 0
for n in nodes:
node_ind += 1
if (
n.op_type in self.ops_to_move
and model.is_fork_node(n)
and not model.is_join_node(n)
):
# Restrict this transform to operations with constant parameters
# Assuming parameters is in input 1
op_init_param = model.get_initializer(n.input[1])
if op_init_param is None:
continue
# Check case when branches are empty and go
# to the same node
consumers = model.find_consumers(n.output[0])
unique_consumer = True
for consum_node in consumers[1:]:
if consumers[0] != consum_node:
unique_consumer = False
break
if unique_consumer:
continue
for consumer_node in consumers[1:]:
# create new node
new_param_name = model.make_new_valueinfo_name()
new_output_tensor_name = model.make_new_valueinfo_name()
new_node = oh.make_node(
n.op_type,
[n.input[0], new_param_name],
[new_output_tensor_name],
)
graph.node.insert(node_ind, new_node)
node_ind += 1
model.set_initializer(new_param_name, op_init_param)
# change consumer input tensor
graph.node.remove(consumer_node)
for idx, consumer_input in enumerate(consumer_node.input):
if consumer_input == n.output[0]:
consumer_node.input[idx] = new_output_tensor_name
break
else:
raise Exception(
"Consumer should have the current node output as input"
)
graph.node.insert(node_ind, consumer_node)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
class MoveAddPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add"])
class MoveMulPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Mul"])
class MoveLinearPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add", "Mul"])
from onnx import TensorProto, helper
import numpy as np
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.reorder import MoveLinearPastFork
from finn.transformation.infer_shapes import InferShapes
import pytest
@pytest.mark.parametrize("ch", [64, 1])
# ifmdim
@pytest.mark.parametrize("ifmdim", [-1, 7])
def test_move_past_fork(ch, ifmdim):
# generate test vectors of correct shape
if ifmdim == -1:
input_shape = (1, ch)
else:
input_shape = (1, ch, ifmdim, ifmdim)
top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, input_shape)
num_of_params = 8
value_info = []
for i in range(num_of_params):
value_info += [
helper.make_tensor_value_info("p" + str(i), TensorProto.FLOAT, input_shape)
]
add_1_to_move = helper.make_node("Add", ["top_in", "p0"], ["fork1"])
mul_1_to_move = helper.make_node("Mul", ["t5", "p4"], ["fork2"])
add_2_to_move = helper.make_node("Add", ["fork2", "p5"], ["t6"])
mul_1_not_to_move = helper.make_node("Mul", ["t8", "p7"], ["fork3"])
modelproto = helper.make_model(
helper.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=[
# fork1
add_1_to_move,
helper.make_node("Mul", ["fork1", "p1"], ["t2"]),
helper.make_node("Mul", ["fork1", "p2"], ["t3"]),
helper.make_node("Add", ["t2", "t3"], ["t4"]),
helper.make_node("Add", ["t4", "p3"], ["t5"]),
# fork2
mul_1_to_move,
add_2_to_move,
helper.make_node("Add", ["fork2", "p6"], ["t7"]),
helper.make_node("Add", ["t6", "t7"], ["t8"]),
# empty branches: do nothing
mul_1_not_to_move,
helper.make_node("Add", ["fork3", "fork3"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
np.random.seed(0)
for i in range(num_of_params):
model.set_initializer(
"p" + str(i), np.random.rand(*input_shape).astype(np.float32)
)
# Transform
new_model = model.transform(MoveLinearPastFork())
inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
# Test
assert oxe.compare_execution(model, new_model, inp_dict)
assert not new_model.is_fork_node(add_1_to_move)
assert not new_model.is_fork_node(mul_1_to_move)
assert not new_model.is_fork_node(add_2_to_move)
assert new_model.is_fork_node(mul_1_not_to_move)
assert len(new_model.graph.node) == 14
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment