From 82a1a1e0a1217596f7f206d00a487008692f6079 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 17 May 2021 11:40:09 +0100
Subject: [PATCH] [Conv] add assertions for square k/img/stride and dilation of
 1

---
 .../fpgadataflow/convolutioninputgenerator.py | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
index a97267e7b..dc583cf90 100644
--- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
+++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
@@ -67,7 +67,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
             "OFMDim": ("ints", True, []),  # [H, W] = [Y, X]
             "SIMD": ("i", True, 0),
             "Stride": ("ints", True, [1, 1]),  # [H, W] = [Y, X]
-            "Dilation": ("ints", True, [1, 1]),  # [H, W] = [Y, X]
+            # note: only dilation=1 supported for now
+            "Dilation": ("ints", True, [1, 1], {[1, 1]}),  # [H, W] = [Y, X]
             # FINN DataTypes for inputs, weights, outputs
             "inputDataType": ("s", True, ""),
             "outputDataType": ("s", True, ""),
@@ -87,6 +88,17 @@ class ConvolutionInputGenerator(HLSCustomOp):
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
 
+    def get_nodeattr(self, name):
+        # overriding get_nodeattr to check for square kernel/img.. requirement
+        # since this can't be done with the attribute restriction in nodeattr_types
+        # TODO non-square can be enabled in theory but needs testing
+        ret = super().get_nodeattr(name)
+        props_to_check = ["ConvKernelDim", "IFMDim", "OFMDim", "Stride", "Dilation"]
+        if name in props_to_check:
+            is_square = ret[0] == ret[1]
+            assert is_square, "Only square %s supported" % name
+        return ret
+
     def get_normal_input_shape(self):
         ifm_dim_h, ifm_dim_w = self.get_nodeattr("IFMDim")
         ifm_ch = self.get_nodeattr("IFMChannels")
@@ -384,14 +396,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
             #define Input_precision1 {}\n #define IFMDim1 {}\n
             #define OFMDim1 {}\n #define SIMD1 {}\n
             #define Stride1 {}\n #define numReps {}""".format(
-                k,
-                ifm_ch,
-                ifm_precision,
-                ifm_dim,
-                ofm_dim,
-                simd,
-                stride,
-                numReps,
+                k, ifm_ch, ifm_precision, ifm_dim, ofm_dim, simd, stride, numReps
             )
         ]
 
-- 
GitLab