Skip to content
Snippets Groups Projects
Commit 28cd2b05 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] check MaxPool padding before skipping

parent 92665b53
No related branches found
No related tags found
No related merge requests found
......@@ -546,11 +546,18 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
node_ind += 1
if n.op_type == "MaxPool" and not model.is_fork_node(n):
consumer = model.find_consumer(n.output[0])
pads = get_by_name(n.attribute, "pads")
has_padding = False
if pads is not None:
pads = list(pads.ints)
has_padding = np.prod(pads) != 0
if consumer is not None and consumer.op_type == "MultiThreshold":
mt_out = consumer.output[0]
mt_odt = model.get_tensor_datatype(mt_out)
if mt_odt.signed():
warnings.warn("Skipping signed-output MultiThreshold")
if mt_odt.signed() and has_padding:
warnings.warn(
"Skipping padded MaxPool + signed-output MultiThreshold"
)
continue
# remove old nodes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment