Skip to content
Snippets Groups Projects
Commit 53cd8cff authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'feature/MoveMaxPoolPastMultiThreshold' of...

Merge branch 'feature/MoveMaxPoolPastMultiThreshold' of https://github.com/quetric/finn into quetric-feature/MoveMaxPoolPastMultiThreshold
parents 78cccb62 5d14a19b
No related branches found
No related tags found
No related merge requests found
......@@ -531,3 +531,54 @@ class MoveMulPastFork(MoveOpPastFork):
class MoveLinearPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add", "Mul"])
class MoveMaxPoolPastMultiThreshold(Transformation):
"""Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
nodes = [n for n in graph.node]
for n in nodes:
node_ind += 1
if n.op_type == "MaxPool" and not model.is_fork_node(n):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
is_signed = True
for attr in consumer.attribute:
if (
attr.name == "out_dtype"
and len(attr.s) >= 5
and attr.s[:4] == b"UINT"
):
is_signed = False
if is_signed:
continue
# remove old nodes
graph.node.remove(n)
graph.node.remove(consumer)
# swap conections
group_in = n.input[0]
# new tensor because dims change
group_middle = model.make_new_valueinfo_name()
group_out = consumer.output[0]
consumer.input[0] = group_in
consumer.output[0] = group_middle
n.input[0] = group_middle
n.output[0] = group_out
# insert them back in
graph.node.insert(node_ind - 1, consumer)
graph.node.insert(node_ind, n)
graph_modified = True
model = model.transform(InferShapes())
return (model, graph_modified)
from onnx import TensorProto, helper
import numpy as np
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold
from finn.transformation.infer_shapes import InferShapes
def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
if seed is not None:
np.random.seed(seed)
steps = np.random.rand(channels, 1) * 2
bias = np.random.rand(channels, 1) * 10
thres = [np.arange(num_of_thres) for chn in range(channels)]
thres = ((thres - bias) * steps).astype(np.float32)
return thres
def test_move_maxpool_past_multithreshold():
# generate test vectors of correct shape
ch = 64
ifmdim = 16
ofmdim = 16 // 4
input_shape = (1, ch, ifmdim, ifmdim)
output_shape = (1, ch, ofmdim, ofmdim)
top_in = helper.make_tensor_value_info("top_in", TensorProto.FLOAT, input_shape)
top_out = helper.make_tensor_value_info("top_out", TensorProto.FLOAT, output_shape)
maxpool_config = {}
maxpool_config["pads"] = [1, 1, 1, 1]
maxpool_config["kernel_shape"] = [3, 3]
maxpool_config["strides"] = [2, 2]
value_info = []
thres1_shape = [1, 1]
value_info += [
helper.make_tensor_value_info("thres1", TensorProto.FLOAT, thres1_shape)
]
thres2_shape = [ch, 14]
value_info += [
helper.make_tensor_value_info("thres2", TensorProto.FLOAT, thres2_shape)
]
nodes = []
nodes += [helper.make_node("MaxPool", ["top_in"], ["t1"], **maxpool_config)]
nodes += [
helper.make_node(
"MultiThreshold",
["t1", "thres1"],
["t2"],
domain="finn",
out_dtype="BIPOLAR",
out_bias=-1.0,
out_scale_f=1.0,
)
]
nodes += [helper.make_node("MaxPool", ["t2"], ["t3"], **maxpool_config)]
nodes += [
helper.make_node(
"MultiThreshold",
["t3", "thres2"],
["top_out"],
domain="finn",
out_dtype="UINT4",
)
]
modelproto = helper.make_model(
helper.make_graph(
name="test",
inputs=[top_in],
outputs=[top_out],
value_info=value_info,
nodes=nodes,
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
model.set_initializer("thres1", np.array([[0]]))
model.set_initializer(
"thres2", get_multithreshold_rand_params(*thres2_shape, seed=0)
)
# Transform
new_model = model.transform(MoveMaxPoolPastMultiThreshold())
inp_dict = {"top_in": np.random.rand(*input_shape).astype(np.float32)}
# Test
assert oxe.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "MaxPool"
assert new_model.graph.node[1].op_type == "MultiThreshold"
assert new_model.graph.node[2].op_type == "MultiThreshold"
assert new_model.graph.node[3].op_type == "MaxPool"
assert len(new_model.graph.node) == 4
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment