Skip to content
Snippets Groups Projects
Commit 7f0d7f0b authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[FPGADataflow] Change order of checks in InferPool_Batch

parent a81fc1e6
No related branches found
No related tags found
No related merge requests found
...@@ -199,16 +199,6 @@ class InferPool_Batch(Transformation): ...@@ -199,16 +199,6 @@ class InferPool_Batch(Transformation):
k = inst.get_nodeattr("kernel") k = inst.get_nodeattr("kernel")
stride = inst.get_nodeattr("stride") stride = inst.get_nodeattr("stride")
if k < stride:
continue
elif k == stride:
warnings.warn(
"""Inferring Pool_Batch node for k == stride.
This case can be optimized.
For example, for MaxPool run InferStreamingMaxPool before
InferPool_Batch """
)
try: try:
pad = get_by_name(n.attribute, "pads").ints[-1] pad = get_by_name(n.attribute, "pads").ints[-1]
except AttributeError: except AttributeError:
...@@ -221,10 +211,15 @@ class InferPool_Batch(Transformation): ...@@ -221,10 +211,15 @@ class InferPool_Batch(Transformation):
if not idt.is_integer(): if not idt.is_integer():
continue continue
# if idt.signed() and n.op_type == "MaxPool": if k < stride:
# # No support for signed input (see accu initialization continue
# # in Pool_batch HLSLIB function) elif k == stride:
# continue warnings.warn(
"""Inferring Pool_Batch node for k == stride.
This case can be optimized.
For example, for MaxPool run InferStreamingMaxPool before
InferPool_Batch """
)
odt = model.get_tensor_datatype(node_output) odt = model.get_tensor_datatype(node_output)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment