diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 7d192aebfd4a7a8c52c479cf060f4d539797934d..3fe1ef242364b1070726681495ae3632660f2904 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np +import warnings from onnx import helper as oh from finn.transformation import Transformation @@ -546,16 +547,10 @@ class MoveMaxPoolPastMultiThreshold(Transformation): if n.op_type == "MaxPool" and not model.is_fork_node(n): consumer = model.find_consumer(n.output[0]) if consumer is not None and consumer.op_type == "MultiThreshold": - is_signed = True - for attr in consumer.attribute: - if ( - attr.name == "out_dtype" - and len(attr.s) >= 5 - and attr.s[:4] == b"UINT" - ): - is_signed = False - - if is_signed: + mt_out = consumer.output[0] + mt_odt = model.get_tensor_datatype(mt_out) + if mt_odt.signed(): + warnings.warn("Skipping signed-output MultiThreshold") continue # remove old nodes diff --git a/tests/transformation/test_move_maxpool_past_multithreshold.py b/tests/transformation/test_move_maxpool_past_multithreshold.py index 0378d956f4bc3160cd93e4a1e27151c71cd40a89..0a40c83f4bd3c35e5bf8d31917b2019479b67268 100644 --- a/tests/transformation/test_move_maxpool_past_multithreshold.py +++ b/tests/transformation/test_move_maxpool_past_multithreshold.py @@ -5,6 +5,7 @@ import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_datatypes import InferDataTypes def get_multithreshold_rand_params(channels, num_of_thres, seed=None): @@ -79,6 +80,7 @@ def test_move_maxpool_past_multithreshold(): ) model = ModelWrapper(modelproto) model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) model.set_initializer("thres1", np.array([[0]])) model.set_initializer(