From 6232549f8ffa2857166b3f2681a30fba63d8d3e7 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 17 Jun 2020 16:55:14 +0100
Subject: [PATCH] [Transform] add more MultiThreshold checks in
 MoveMaxPoolPastMultiThreshold

---
 src/finn/transformation/streamline/reorder.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 9348f917b..b46b82c77 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -34,6 +34,7 @@ from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
 from finn.core.onnx_exec import execute_node
 from finn.util.basic import get_by_name
+from finn.custom_op.registry import getCustomOp
 
 
 class MoveAddPastMul(Transformation):
@@ -559,6 +560,18 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
                             "Skipping padded MaxPool + signed-output MultiThreshold"
                         )
                         continue
+                    # check for non-decreasing thresholds and nonnegative
+                    # scale factor in MultiThreshold
+                    # otherwise we cannot do the reordering
+                    T = model.get_initializer(consumer.input[1])
+                    T_sorted = np.sort(T, axis=1)
+                    assert (
+                        T == T_sorted
+                    ).all(), "MultiThreshold must have non-decreasing thresholds"
+                    mt_inst = getCustomOp(consumer)
+                    if mt_inst.get_nodeattr("out_scale") < 0:
+                        warnings.warn("Skipping MultiThreshold with negative out_scale")
+                        continue
 
                     # remove old nodes
                     graph.node.remove(n)
-- 
GitLab