From 7f0d7f0b864833b6d189748ec50eb7fe34ba9d15 Mon Sep 17 00:00:00 2001 From: Tobi-Alonso <tobi.alonso@gmail.com> Date: Wed, 1 Jul 2020 10:55:37 +0100 Subject: [PATCH] [FPGADataflow] Change order of checks in InferPool_Batch --- .../fpgadataflow/convert_to_hls_layers.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index 26e7c2e77..cb3b1dc3b 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -199,16 +199,6 @@ class InferPool_Batch(Transformation): k = inst.get_nodeattr("kernel") stride = inst.get_nodeattr("stride") - if k < stride: - continue - elif k == stride: - warnings.warn( - """Inferring Pool_Batch node for k == stride. - This case can be optimized. - For example, for MaxPool run InferStreamingMaxPool before - InferPool_Batch """ - ) - try: pad = get_by_name(n.attribute, "pads").ints[-1] except AttributeError: @@ -221,10 +211,15 @@ class InferPool_Batch(Transformation): if not idt.is_integer(): continue - # if idt.signed() and n.op_type == "MaxPool": - # # No support for signed input (see accu initialization - # # in Pool_batch HLSLIB function) - # continue + if k < stride: + continue + elif k == stride: + warnings.warn( + """Inferring Pool_Batch node for k == stride. + This case can be optimized. + For example, for MaxPool run InferStreamingMaxPool before + InferPool_Batch """ + ) odt = model.get_tensor_datatype(node_output) -- GitLab