From b9ac818352efa713d15b101c073e1ade7e57e470 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 26 Jul 2022 10:08:22 +0200
Subject: [PATCH] [Downsample] add 1D downsample support + conversion

---
 .../custom_op/fpgadataflow/downsampler.py     | 31 ++++++++++++++++---
 .../fpgadataflow/convert_to_hls_layers.py     | 23 ++++++++------
 2 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/downsampler.py b/src/finn/custom_op/fpgadataflow/downsampler.py
index b8645e04e..e9009e185 100644
--- a/src/finn/custom_op/fpgadataflow/downsampler.py
+++ b/src/finn/custom_op/fpgadataflow/downsampler.py
@@ -55,6 +55,10 @@ class DownSampler(HLSCustomOp):
             "inputDataType": ("s", True, ""),
             # Batch size
             "numInputVectors": ("i", False, 1),
+            # 1D (True) or 2D (False) spatial data
+            "is1D": ("i", False, 0),
+            # for 1D only: (D, 1) (True) or (1, D) dims
+            "is1D_unitx": ("i", False, 1),
         }
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
@@ -66,25 +70,43 @@ class DownSampler(HLSCustomOp):
         return int(np.floor((idim - 1) / stride) + 1)
 
     def get_exp_cycles(self):
+        is_1D = self.get_nodeattr("is1D")
         idim = self.get_nodeattr("ImgDim")
+        idim_total = idim if is_1D else idim * idim
         channels = self.get_nodeattr("NumChannels")
         simd = self.get_nodeattr("SIMD")
         batch_size = self.get_nodeattr("numInputVectors")
-        exp_cycles = channels / simd * batch_size * idim * idim
+        exp_cycles = channels / simd * batch_size * idim_total
         return int(exp_cycles)
 
     def get_normal_input_shape(self):
+        is_1D = self.get_nodeattr("is1D")
+        is_1D_unitx = self.get_nodeattr("is1D_unitx")
         idim = self.get_nodeattr("ImgDim")
         num_ch = self.get_nodeattr("NumChannels")
         batch = self.get_nodeattr("numInputVectors")
-        ishape = (batch, idim, idim, num_ch)
+        if is_1D:
+            if is_1D_unitx:
+                ishape = (batch, idim, 1, num_ch)
+            else:
+                ishape = (batch, 1, idim, num_ch)
+        else:
+            ishape = (batch, idim, idim, num_ch)
         return ishape
 
     def get_normal_output_shape(self):
+        is_1D = self.get_nodeattr("is1D")
+        is_1D_unitx = self.get_nodeattr("is1D_unitx")
         odim = self.get_downsampled_odim()
         num_ch = self.get_nodeattr("NumChannels")
         batch = self.get_nodeattr("numInputVectors")
-        oshape = (batch, odim, odim, num_ch)
+        if is_1D:
+            if is_1D_unitx:
+                oshape = (batch, odim, 1, num_ch)
+            else:
+                oshape = (batch, 1, odim, num_ch)
+        else:
+            oshape = (batch, odim, odim, num_ch)
         return oshape
 
     def get_folded_input_shape(self):
@@ -204,8 +226,9 @@ class DownSampler(HLSCustomOp):
         )
 
     def docompute(self):
+        dim_var = "1D" if (self.get_nodeattr("is1D") == 1) else "2D"
         self.code_gen_dict["$DOCOMPUTE$"] = [
-            """ConvolutionInputGenerator_2D_kernel1<IFMChannels, Input_precision,
+            f"""ConvolutionInputGenerator_{dim_var}_kernel1<IFMChannels, Input_precision,
             IFMDim, SIMD,Stride> (in0, out, numReps);"""
         ]
 
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index 429bc34ff..9059e023a 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -138,17 +138,18 @@ class InferConvInpGen(Transformation):
                 )
 
                 if (stride_h > 1 or stride_w > 1) and is_kernel_pointwise:
-                    assert is_square_image, (
-                        "%s : DownSampler currently only supports square input images."
-                        % n.name
-                    )
-                    assert is_equal_stride, (
-                        """%s : DownSampler currently only supports equal stride value
-                        along different axes."""
-                        % n.name
+                    downsample_1D = (ifm_dim_h == 1) or (ifm_dim_w == 1)
+                    is1D_unitx = ifm_dim_w == 1
+                    downsample_2D = (
+                        (not downsample_1D) and is_square_image and is_equal_stride
                     )
-                    ConvInpGen_idim = ConvInpGen_idim_h
-                    stride = stride_h
+                    if not (downsample_1D or downsample_2D):
+                        warnings.warn(
+                            f"Couldn't infer Downsample from {n.name}, check config."
+                        )
+                        continue
+                    ConvInpGen_idim = max(ConvInpGen_idim_h, ConvInpGen_idim_w)
+                    stride = max(stride_h, stride_w)
                     # create DownSampler node
                     ConvInpGen_node = helper.make_node(
                         "DownSampler",
@@ -162,6 +163,8 @@ class InferConvInpGen(Transformation):
                         Stride=stride,
                         inputDataType=dt.name,
                         name="DownSampler_" + n.name,
+                        is1D=downsample_1D,
+                        is1D_unitx=is1D_unitx,
                     )
                     graph.node.insert(ConvInpGen_node_idx, ConvInpGen_node)
                 else:
-- 
GitLab