From 7f0d7f0b864833b6d189748ec50eb7fe34ba9d15 Mon Sep 17 00:00:00 2001
From: Tobi-Alonso <tobi.alonso@gmail.com>
Date: Wed, 1 Jul 2020 10:55:37 +0100
Subject: [PATCH] [FPGADataflow] Change order of checks in InferPool_Batch

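Move the k < stride / k == stride checks below the padding lookup and the
integer input-datatype check, so the k == stride warning is only emitted
for nodes whose input datatype is an integer type. Also drop the
commented-out guard against signed MaxPool inputs.
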
---
 .../fpgadataflow/convert_to_hls_layers.py     | 23 ++++++++-----------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 26e7c2e77..cb3b1dc3b 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -199,16 +199,6 @@ class InferPool_Batch(Transformation):
                     k = inst.get_nodeattr("kernel")
                     stride = inst.get_nodeattr("stride")
 
-                if k < stride:
-                    continue
-                elif k == stride:
-                    warnings.warn(
-                        """Inferring Pool_Batch node for k == stride.
-                        This case can be optimized.
-                        For example, for MaxPool run InferStreamingMaxPool before
-                        InferPool_Batch """
-                    )
-
                 try:
                     pad = get_by_name(n.attribute, "pads").ints[-1]
                 except AttributeError:
@@ -221,10 +211,15 @@ class InferPool_Batch(Transformation):
                 if not idt.is_integer():
                     continue
 
-                # if idt.signed() and n.op_type == "MaxPool":
-                #     # No support for signed input (see accu initialization
-                #     # in Pool_batch HLSLIB function)
-                #     continue
+                if k < stride:
+                    continue
+                elif k == stride:
+                    warnings.warn(
+                        """Inferring Pool_Batch node for k == stride.
+                        This case can be optimized.
+                        For example, for MaxPool run InferStreamingMaxPool before
+                        InferPool_Batch """
+                    )
 
                 odt = model.get_tensor_datatype(node_output)
 
-- 
GitLab
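
Appended note (not part of the patch): a minimal sketch of the decision order
that results from this reordering. should_infer_pool_batch is a hypothetical
helper, not FINN code; it only illustrates that the integer input-datatype
check now precedes the kernel/stride checks, so the k == stride warning is no
longer raised for nodes that are skipped because their input datatype is not
an integer type.

    import warnings


    def should_infer_pool_batch(k, stride, idt_is_integer):
        """Hypothetical helper mirroring the post-patch order of checks."""
        # The integer input-datatype check now runs first ...
        if not idt_is_integer:
            return False
        # ... so the kernel/stride checks, and the k == stride warning,
        # are only reached for nodes with an integer input datatype.
        if k < stride:
            return False
        elif k == stride:
            warnings.warn(
                "Inferring Pool_Batch node for k == stride. This case can be "
                "optimized: for MaxPool, run InferStreamingMaxPool before "
                "InferPool_Batch."
            )
        return True


    # k == stride, non-integer input datatype: skipped without a warning.
    print(should_infer_pool_batch(k=2, stride=2, idt_is_integer=False))  # False
    # k == stride, integer input datatype: warns, then proceeds to convert.
    print(should_infer_pool_batch(k=2, stride=2, idt_is_integer=True))   # True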