From dd00db030f147bde64b05cc56c3d034e7bcac9e9 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 8 Jul 2020 17:02:12 +0100
Subject: [PATCH] [CustomOp] Add output activation support to VVAU node

---
 .../vector_vector_activate_batch.py           | 84 ++++++++++++++-----
 1 file changed, 62 insertions(+), 22 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py b/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
index 5a7db03c1..d7339f761 100644
--- a/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
+++ b/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
@@ -11,13 +11,6 @@ from finn.util.data_packing import (
     rtlsim_output_to_npy,
 )
 
-# ONNX i/o tensor shape assumptions for Vector_Vector_Activate_Batch:
-# input 0 is the input tensor, shape (.., i_size) = (..., MW)
-# input 1 is the weight tensor, shape (i_size, o_size) = (MW, MH)
-# (optional) input 2 is the thresholds tensor, shape (o_size, n_thres)
-# output 0 is the output tensor, shape (.., o_size) = (..., MH)
-# the ... here can be any shape (representing groups of vectors)
-
 
 class Vector_Vector_Activate_Batch(HLSCustomOp):
     """Class that corresponds to finn-hlslib Vector_Vector_Activate_Batch function"""
@@ -158,22 +151,13 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
         ret = dict()
         inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
         out_hls_str = self.get_output_datatype().get_hls_datatype_str()
-        # inp_is_binary = self.get_input_datatype() == DataType.BINARY
-        # wt_is_binary = self.get_weight_datatype() == DataType.BINARY
         inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
         wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
         # fill in TSrcI and TWeightI
         # TODO handle non-bipolar binary inputs
-        if inp_is_bipolar and wt_is_bipolar:
-            ret["TSrcI"] = "Recast<XnorMul>"
-            ret["TWeightI"] = "Identity"
-        elif (not inp_is_bipolar) and wt_is_bipolar:
-            ret["TSrcI"] = "Slice<%s>" % inp_hls_str
-            ret["TWeightI"] = "Recast<Binary>"
-        elif inp_is_bipolar and (not wt_is_bipolar):
-            ret["TSrcI"] = "Recast<Binary>"
-            ret["TWeightI"] = "Identity"
-        elif (not inp_is_bipolar) and (not wt_is_bipolar):
+        if inp_is_bipolar or wt_is_bipolar:
+            raise Exception("VVAU node doesn't support bipolar values yet.")
+        else:
             ret["TSrcI"] = "Slice<%s>" % inp_hls_str
             ret["TWeightI"] = "Identity"
 
@@ -202,6 +186,33 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
         ret = ret.reshape(1, pe, wmem, 1)
         return ret
 
+    def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
+        ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        tmem = self.calc_tmem()
+        assert ch % pe == 0, "Requirement Channels divisable by PE is violated."
+        assert (
+            orig_thres_matrix.ndim == 2
+        ), """Threshold matrix dimension is
+        not as expected (2)."""
+        n_thres_steps = orig_thres_matrix.shape[1]
+        ret = orig_thres_matrix
+        # distribute rows between PEs
+        ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
+        assert (
+            ret.shape[0] == pe
+        ), """First dimension after distribution of the
+        rows between PEs is not as expected (pe)"""
+        assert (
+            ret.shape[1] == tmem
+        ), """Second dimension after distribution of the
+        rows between PEs is not as expected (tmem)"""
+        assert (
+            ret.shape[2] == n_thres_steps
+        ), """Third dimension after distribution of the
+        rows between PEs is not as expected (n_thres_steps)"""
+        return ret.reshape(1, pe, tmem, n_thres_steps)
+
     def generate_params(self, model, path):
         # weights
         weights = model.get_initializer(self.onnx_node.input[1])
@@ -238,6 +249,36 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
         f_weights.write(weight_hls_code)
         f_weights.close()
 
+        # save thresholds in thresh.h
+        if len(self.onnx_node.input) > 2:
+            thresholds = model.get_initializer(self.onnx_node.input[2])
+            if thresholds is not None:
+                threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
+                tdt = DataType.INT32
+                thresholds_hls_code = numpy_to_hls_code(
+                    threshold_tensor, tdt, "thresholds", False, True
+                )
+                # write thresholds into thresh.h
+                f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
+                tdt_hls = tdt.get_hls_datatype_str()
+                # use binary to export bipolar activations
+                odt = self.get_output_datatype()
+                odt_hls = odt.get_hls_datatype_str()
+                f_thresh.write(
+                    "static ThresholdsActivation<{},{},{},{},{},{},{}> threshs \
+                    = ".format(
+                        self.calc_tmem(),
+                        self.get_nodeattr("PE"),
+                        threshold_tensor.shape[-1],
+                        tdt_hls,
+                        odt_hls,
+                        self.get_nodeattr("ActVal"),
+                        "std::less_equal<%s>" % tdt_hls,
+                    )
+                )
+                f_thresh.write(thresholds_hls_code)
+                f_thresh.close()
+
     def execute_node(self, context, graph):
         mode = self.get_nodeattr("exec_mode")
         node = self.onnx_node
@@ -336,9 +377,8 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
     def global_includes(self):
         self.code_gen_dict["$GLOBALS$"] = ['#include "weights.hpp"']
         self.code_gen_dict["$GLOBALS$"] += ['#include "activations.hpp"']
-        # if self.calc_tmem() != 0:
-        #    # TODO find a better way of checking for no pregenerated thresholds
-        #    self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
+        if self.calc_tmem() != 0:
+            self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
 
     def defines(self, var):
         dim = self.get_nodeattr("Dim")
-- 
GitLab