From a8d1b5c199e0fc5a169001f98adcedc0ba917c33 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Tue, 10 Dec 2019 12:47:05 +0000
Subject: [PATCH] [Execution and test] Added verification (and test) for
 StreamingMaxPool_Batch

---
 .../custom_op/verify_custom_op_construct.py   | 27 +++++++++++++++----
 tests/test_verify_custom_ops.py               | 22 ++++++++++++++-
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/src/finn/custom_op/verify_custom_op_construct.py b/src/finn/custom_op/verify_custom_op_construct.py
index 07d47a959..42ce2e6e2 100644
--- a/src/finn/custom_op/verify_custom_op_construct.py
+++ b/src/finn/custom_op/verify_custom_op_construct.py
@@ -4,7 +4,7 @@ from finn.core.utils import get_by_name
 class CustomOp_Construct(Enum):
     MultiThreshold = auto()
     XnorPopcountMatMul = auto()
-    #StreamingMaxPool_Batch = auto()
+    StreamingMaxPool_Batch = auto()
     StreamingFCLayer_Batch = auto()
 
     def verify_construct(self, node):
@@ -20,10 +20,13 @@ class CustomOp_Construct(Enum):
             num_of_attr = 3
         elif self.name == "XnorPopcountMatMul":
             num_of_attr = 0
-        #elif self.name == "StreamingMaxPool_Batch":
-            #num_of_attr =
+        elif self.name == "StreamingMaxPool_Batch":
+            num_of_attr = 6
         elif self.name == "StreamingFCLayer_Batch":
             num_of_attr = 14
+        else:
+            Exception("CustomOp {} is not yet in the verification".format(node.op_type))
+
         if len(node.attribute) == num_of_attr:
             return True
         else:
@@ -57,7 +60,17 @@ class CustomOp_Construct(Enum):
 
         elif self.name == "XnorPopcountMatMul":
             return True
-        #elif self.name == "StreamingMaxPool_Batch":
+        elif self.name == "StreamingMaxPool_Batch":
+            try:
+                get_by_name(node.attribute, "code_gen_dir")
+                get_by_name(node.attribute, "executable_path")
+                get_by_name(node.attribute, "ImgDim")
+                get_by_name(node.attribute, "PoolDim")
+                get_by_name(node.attribute, "NumChannels")
+                return True
+            except:
+                return False
+
         elif self.name == "StreamingFCLayer_Batch":
             try:
                 get_by_name(node.attribute, "code_gen_dir")
@@ -89,7 +102,11 @@ class CustomOp_Construct(Enum):
                 return True
             else:
                 return False
-        #elif self.name == "StreamingMaxPool_Batch":
+        elif self.name == "StreamingMaxPool_Batch":
+            if len(node.input) == 1:
+                return True
+            else:
+                return False
         elif self.name == "StreamingFCLayer_Batch":
             # check noActivation value to determine the number of inputs
             no_act = get_by_name(node.attribute, "noActivation")
diff --git a/tests/test_verify_custom_ops.py b/tests/test_verify_custom_ops.py
index c05278b1d..dd54c5312 100644
--- a/tests/test_verify_custom_ops.py
+++ b/tests/test_verify_custom_ops.py
@@ -24,6 +24,25 @@ def test_verify_layout_custom_ops():
     inst = CustomOp_Construct[xnor_node.op_type]
     inst.verify_construct(xnor_node)
 
+    # StreamingMaxPool_Batch
+    MaxPool_batch_node = helper.make_node(
+        "StreamingMaxPool_Batch",
+        ["in"],
+        ["out"],
+        domain="finn",
+        backend="fpgadataflow",
+        code_gen_dir="",
+        executable_path="",
+        ImgDim=4,
+        PoolDim=2,
+        NumChannels=2,
+    )
+
+    inst = CustomOp_Construct[MaxPool_batch_node.op_type]
+    inst.verify_construct(MaxPool_batch_node)
+    
+    
+    
     # StreamingFCLayer_Batch - no activation
     FCLayer_node = helper.make_node(
         "StreamingFCLayer_Batch",
@@ -43,7 +62,8 @@ def test_verify_layout_custom_ops():
         outputDataType="<FINN DataType>",
         ActVal=0,
         binaryXnorMode=1,
-        noActivation=1
+        noActivation=1,
+    )
     
     inst = CustomOp_Construct[FCLayer_node.op_type]
     inst.verify_construct(FCLayer_node)
-- 
GitLab