Skip to content
Snippets Groups Projects
Commit 92665b53 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] use DataType.signed() in MoveMaxPoolPastMultiThreshold

added dt inference to test case
parent 53cd8cff
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
import warnings
from onnx import helper as oh
from finn.transformation import Transformation
......@@ -546,16 +547,10 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
if n.op_type == "MaxPool" and not model.is_fork_node(n):
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold":
is_signed = True
for attr in consumer.attribute:
if (
attr.name == "out_dtype"
and len(attr.s) >= 5
and attr.s[:4] == b"UINT"
):
is_signed = False
if is_signed:
mt_out = consumer.output[0]
mt_odt = model.get_tensor_datatype(mt_out)
if mt_odt.signed():
warnings.warn("Skipping signed-output MultiThreshold")
continue
# remove old nodes
......
......@@ -5,6 +5,7 @@ import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
......@@ -79,6 +80,7 @@ def test_move_maxpool_past_multithreshold():
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model.set_initializer("thres1", np.array([[0]]))
model.set_initializer(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment