From 6232549f8ffa2857166b3f2681a30fba63d8d3e7 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Wed, 17 Jun 2020 16:55:14 +0100 Subject: [PATCH] [Transform] add more MultiThreshold checks in MoveMaxPoolPastMultiThreshold --- src/finn/transformation/streamline/reorder.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 9348f917b..b46b82c77 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -34,6 +34,7 @@ from finn.transformation import Transformation from finn.transformation.infer_shapes import InferShapes from finn.core.onnx_exec import execute_node from finn.util.basic import get_by_name +from finn.custom_op.registry import getCustomOp class MoveAddPastMul(Transformation): @@ -559,6 +560,18 @@ class MoveMaxPoolPastMultiThreshold(Transformation): "Skipping padded MaxPool + signed-output MultiThreshold" ) continue + # check for non-decreasing thresholds and nonnegative + # scale factor in MultiThreshold + # otherwise we cannot do the reordering + T = model.get_initializer(consumer.input[1]) + T_sorted = np.sort(T, axis=1) + assert ( + T == T_sorted + ).all(), "MultiThreshold must have non-decreasing thresholds" + mt_inst = getCustomOp(consumer) + if mt_inst.get_nodeattr("out_scale") < 0: + warnings.warn("Skipping MultiThreshold with negative out_scale") + continue # remove old nodes graph.node.remove(n) -- GitLab