Skip to content
Snippets Groups Projects
Commit 92665b53 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] use DataType.signed() in MoveMaxPoolPastMultiThreshold

added dt inference to test case
parent 53cd8cff
No related branches found
No related tags found
No related merge requests found
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np import numpy as np
import warnings
from onnx import helper as oh from onnx import helper as oh
from finn.transformation import Transformation from finn.transformation import Transformation
...@@ -546,16 +547,10 @@ class MoveMaxPoolPastMultiThreshold(Transformation): ...@@ -546,16 +547,10 @@ class MoveMaxPoolPastMultiThreshold(Transformation):
if n.op_type == "MaxPool" and not model.is_fork_node(n): if n.op_type == "MaxPool" and not model.is_fork_node(n):
consumer = model.find_consumer(n.output[0]) consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "MultiThreshold": if consumer is not None and consumer.op_type == "MultiThreshold":
is_signed = True mt_out = consumer.output[0]
for attr in consumer.attribute: mt_odt = model.get_tensor_datatype(mt_out)
if ( if mt_odt.signed():
attr.name == "out_dtype" warnings.warn("Skipping signed-output MultiThreshold")
and len(attr.s) >= 5
and attr.s[:4] == b"UINT"
):
is_signed = False
if is_signed:
continue continue
# remove old nodes # remove old nodes
......
...@@ -5,6 +5,7 @@ import finn.core.onnx_exec as oxe ...@@ -5,6 +5,7 @@ import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold from finn.transformation.streamline.reorder import MoveMaxPoolPastMultiThreshold
from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
def get_multithreshold_rand_params(channels, num_of_thres, seed=None): def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
...@@ -79,6 +80,7 @@ def test_move_maxpool_past_multithreshold(): ...@@ -79,6 +80,7 @@ def test_move_maxpool_past_multithreshold():
) )
model = ModelWrapper(modelproto) model = ModelWrapper(modelproto)
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model.set_initializer("thres1", np.array([[0]])) model.set_initializer("thres1", np.array([[0]]))
model.set_initializer( model.set_initializer(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment