Skip to content
Snippets Groups Projects
Commit 6232549f authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add more MultiThreshold checks in MoveMaxPoolPastMultiThreshold

parent d6c04a6c
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.core.onnx_exec import execute_node
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
class MoveAddPastMul(Transformation):
......@@ -559,6 +560,18 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
"Skipping padded MaxPool + signed-output MultiThreshold"
)
continue
# check for non-decreasing thresholds and nonnegative
# scale factor in MultiThreshold
# otherwise we cannot do the reordering
T = model.get_initializer(consumer.input[1])
T_sorted = np.sort(T, axis=1)
assert (
T == T_sorted
).all(), "MultiThreshold must have non-decreasing thresholds"
mt_inst = getCustomOp(consumer)
if mt_inst.get_nodeattr("out_scale") < 0:
warnings.warn("Skipping MultiThreshold with negative out_scale")
continue
# remove old nodes
graph.node.remove(n)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment