From 7faf316b42e382c8490459fdbddd64aafdb65851 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 31 Jan 2023 09:34:04 +0000
Subject: [PATCH] [DWC] handle indivisible widths via LCM-sized intermediate
 stream

InsertDWC now uses hls mode, and vivado mode is only used for
8-bit divisible stream widths
---
 .../streamingdatawidthconverter_batch.py      | 79 ++++++++++++++-----
 .../transformation/fpgadataflow/insert_dwc.py |  9 ++-
 2 files changed, 64 insertions(+), 24 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
index 33ba6450f..940b4f4ab 100644
--- a/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingdatawidthconverter_batch.py
@@ -78,24 +78,33 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
 
     def check_divisible_iowidths(self):
         impl_style = self.get_nodeattr("impl_style")
-        if impl_style == "hls":
-            # when using impl_style = hls must have the following
-            # if inWidth > outWidth: inWidth % outWidth = 0
-            # if inWidth < outWidth: outWidth % inWidth = 0
-            iwidth = self.get_nodeattr("inWidth")
-            owidth = self.get_nodeattr("outWidth")
-            if iwidth > owidth:
-                assert (
-                    iwidth % owidth == 0
-                ), """DWC InWidth is bigger than OutWidth and is not divisible by it.
-                Please adjust PE and SIMD values so that InWidth % OutWidth = 0
-                or alternatively use impl_style = vivado"""
-            else:
-                assert (
-                    owidth % iwidth == 0
-                ), """DWC OutWidth is bigger than InWidth and is not divisible by it.
-                Please adjust PE and SIMD values so that OutWidth % InWidth = 0
-                or alternatively use impl_style = vivado"""
+        iwidth = self.get_nodeattr("inWidth")
+        owidth = self.get_nodeattr("outWidth")
+        if impl_style == "vivado":
+            # the AXIS IP we use in vivado mode only supports
+            # stream widths that are divisible by 8
+            iwidth_d8 = iwidth % 8 == 0
+            owidth_d8 = owidth % 8 == 0
+            assert (
+                iwidth_d8 and owidth_d8
+            ), """DWC impl_style=vivado requires
+            stream widths that are divisible by 8: (%d, %d)""" % (
+                iwidth,
+                owidth,
+            )
+
+    def get_iowidth_lcm(self):
+        iwidth = self.get_nodeattr("inWidth")
+        owidth = self.get_nodeattr("outWidth")
+        return int(np.lcm(iwidth, owidth))
+
+    def needs_lcm(self):
+        iwidth = self.get_nodeattr("inWidth")
+        owidth = self.get_nodeattr("outWidth")
+        maxwidth = max(iwidth, owidth)
+        minwidth = min(iwidth, owidth)
+        impl_style = self.get_nodeattr("impl_style")
+        return (impl_style == "hls") and (maxwidth % minwidth != 0)
 
     def get_folded_input_shape(self, ind=0):
         self.check_divisible_iowidths()
@@ -202,6 +211,16 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
             "#define NumInWords %d " % numInWords,
             "#define numReps %d" % numReps,
         ]
+        if self.needs_lcm():
+            lcmWidth = self.get_iowidth_lcm()
+            assert (
+                numInWords % (lcmWidth / inWidth) == 0
+            ), "Error in DWC LCM calculation"
+            numLCMToOut = numInWords // (lcmWidth / inWidth)
+            self.code_gen_dict["$DEFINES$"].append("#define LCMWidth %d" % lcmWidth)
+            self.code_gen_dict["$DEFINES$"].append(
+                "#define NumLCMToOut %d" % (numLCMToOut)
+            )
 
     def read_npy_data(self):
         code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
@@ -226,6 +245,12 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
         self.code_gen_dict["$STREAMDECLARATIONS$"].append(
             'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
         )
+        if self.needs_lcm():
+            self.code_gen_dict["$STREAMDECLARATIONS$"].append(
+                'hls::stream<ap_uint<{}>> intermediate ("intermediate");'.format(
+                    self.get_iowidth_lcm()
+                )
+            )
         self.code_gen_dict["$STREAMDECLARATIONS$"].append(
             'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
         )
@@ -233,9 +258,19 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
     def docompute(self):
         # TODO continue with fxns below, they are copy-pasted
         op = "StreamingDataWidthConverter_Batch"
-        self.code_gen_dict["$DOCOMPUTE$"] = [
-            "%s<InWidth, OutWidth, NumInWords>(in0, out, numReps);" % (op)
-        ]
+        if self.needs_lcm():
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                'hls::stream<ap_uint<{}>> intermediate ("intermediate");'.format(
+                    self.get_iowidth_lcm()
+                ),
+                "%s<InWidth, LCMWidth, NumInWords>(in0, intermediate, numReps);" % (op),
+                "%s<LCMWidth, OutWidth, NumLCMToOut>(intermediate, out, numReps);"
+                % (op),
+            ]
+        else:
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                "%s<InWidth, OutWidth, NumInWords>(in0, out, numReps);" % (op)
+            ]
 
     def dataoutstrm(self):
         code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
@@ -287,6 +322,8 @@ class StreamingDataWidthConverter_Batch(HLSCustomOp):
         self.code_gen_dict["$PRAGMAS$"].append(
             "#pragma HLS INTERFACE ap_ctrl_none port=return"
         )
+        if self.needs_lcm():
+            self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS DATAFLOW")
 
     def execute_node(self, context, graph):
         mode = self.get_nodeattr("exec_mode")
diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py
index efc179923..632d1f813 100644
--- a/src/finn/transformation/fpgadataflow/insert_dwc.py
+++ b/src/finn/transformation/fpgadataflow/insert_dwc.py
@@ -83,10 +83,13 @@ class InsertDWC(Transformation):
                             dwc_out_width = n1.get_instream_width()
                             larger_width = max(dwc_in_width, dwc_out_width)
                             smaller_width = min(dwc_in_width, dwc_out_width)
-                            if larger_width % smaller_width == 0:
-                                impl_style = "hls"
-                            else:
+                            both_8bit_aligned = (larger_width % 8 == 0) and (
+                                smaller_width % 8 == 0
+                            )
+                            if both_8bit_aligned:
                                 impl_style = "vivado"
+                            else:
+                                impl_style = "hls"
 
                             # determine shape for dwc
                             dwc_shape = n0.get_normal_output_shape()
-- 
GitLab