From 206f4002821062427212cd4071ace1776347de1e Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Fri, 22 Jul 2022 19:25:39 +0200
Subject: [PATCH] [VVAU] binaryXnorMode support (untested)

---
 .../fpgadataflow/vectorvectoractivation.py    | 57 +++++++++++++++++--
 1 file changed, 52 insertions(+), 5 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
index 9d0a9ee52..24cb3101f 100644
--- a/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/vectorvectoractivation.py
@@ -96,6 +96,9 @@ class VectorVectorActivation(HLSCustomOp):
                 "auto",
                 {"auto", "block", "distributed", "ultra"},
             ),
+            # use xnor-popcount for binary weights/inputs, thus treating them
+            # as bipolar
+            "binaryXnorMode": ("i", False, 0, {0, 1}),
         }
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
@@ -289,13 +292,31 @@ class VectorVectorActivation(HLSCustomOp):
         ret = dict()
         inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
         out_hls_str = self.get_output_datatype().get_hls_datatype_str()
+        inp_is_binary = self.get_input_datatype() == DataType["BINARY"]
+        # out_is_binary = self.get_output_datatype() == DataType["BINARY"]
+        wt_is_binary = self.get_weight_datatype() == DataType["BINARY"]
+        bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1
+        if (inp_is_binary or wt_is_binary) and (not bin_xnor_mode):
+            raise Exception("True binary (non-bipolar) inputs not yet supported")
         inp_is_bipolar = self.get_input_datatype() == DataType["BIPOLAR"]
+        # out_is_bipolar = self.get_output_datatype() == DataType["BIPOLAR"]
         wt_is_bipolar = self.get_weight_datatype() == DataType["BIPOLAR"]
+        # reinterpret inp/wt as bipolar if bin_xnor_mode is iset
+        inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode)
+        wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode)
         # fill in TSrcI and TWeightI
-        # TODO handle bipolar inputs
-        if inp_is_bipolar or wt_is_bipolar:
-            raise Exception("VVAU node doesn't support bipolar values yet.")
-        else:
+        # TODO check these with Giulio
+        # TODO handle non-bipolar binary inputs
+        if inp_is_bipolar and wt_is_bipolar:
+            ret["TSrcI"] = "Recast<XnorMul>"
+            ret["TWeightI"] = "Identity"
+        elif (not inp_is_bipolar) and wt_is_bipolar:
+            ret["TSrcI"] = "Slice<%s>" % inp_hls_str
+            ret["TWeightI"] = "Recast<Binary>"
+        elif inp_is_bipolar and (not wt_is_bipolar):
+            ret["TSrcI"] = "Recast<Binary>"
+            ret["TWeightI"] = "Identity"
+        elif (not inp_is_bipolar) and (not wt_is_bipolar):
             ret["TSrcI"] = "Slice<%s>" % inp_hls_str
             ret["TWeightI"] = "Identity"
 
@@ -324,6 +345,13 @@ class VectorVectorActivation(HLSCustomOp):
         return ret
 
     def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
+        """Convert the original numpy weight matrix orig_weight_matrix into
+        a form suitable for passing to the hlslib call:
+        * ensure MH % PE == 0
+        * for bipolar weights&inputs, ensure thresholds are positive
+        * interleave rows between PEs
+        * reshape into (PE, TMEM, n_thres_steps) and return
+        """
         ch = self.get_nodeattr("Channels")
         pe = self.get_nodeattr("PE")
         tmem = self.calc_tmem()
@@ -333,14 +361,33 @@ class VectorVectorActivation(HLSCustomOp):
         ), """Threshold matrix dimension is
         not as expected (2)."""
         n_thres_steps = orig_thres_matrix.shape[1]
+        inp_is_bipolar = self.get_input_datatype() == DataType["BIPOLAR"]
+        wt_is_bipolar = self.get_weight_datatype() == DataType["BIPOLAR"]
+        # reinterpret inp/wt as bipolar if bin_xnor_mode is iset
+        inp_is_binary = self.get_input_datatype() == DataType["BINARY"]
+        wt_is_binary = self.get_weight_datatype() == DataType["BINARY"]
+        bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1
+        inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode)
+        wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode)
+        if inp_is_bipolar and wt_is_bipolar:
+            # ensure all thresholds are nonnegative
+            assert (orig_thres_matrix >= 0).all()
+            # ensure all thresholds are integer
+            assert (orig_thres_matrix.astype(np.int32) == orig_thres_matrix).all()
         ret = orig_thres_matrix
         # workaround for vivado_hls threshold bug
-        if ret[0][0] == 0:
+        if ret[0][0] == 0 and n_thres_steps == 1:
             ret = np.copy(ret)
             ret[0][0] = 1
             warnings.warn(
                 "Setting 0-valued first threshold to 1 to avoid vivado_hls bug"
             )
+        # ensure channels = mh , duplicating if necessary
+        if ret.shape[0] == 1:
+            ret = np.tile(ret, (ch, 1))
+        assert (
+            ret.shape[0] == ch
+        ), "Channels of threshold matrix are not as expected (ch)"
         # distribute rows between PEs
         ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
         assert (
-- 
GitLab