From 9c134a5116c6bf5c30e7a255a6e56217b54c41db Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 3 Dec 2019 00:54:56 +0000
Subject: [PATCH] [CustomOp] generalize param gen fxn to a single optional one

---
 src/finn/custom_op/fpgadataflow/__init__.py            | 10 ++--------
 .../custom_op/fpgadataflow/streamingfclayer_batch.py   |  6 +++---
 2 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index 0d9a306ba..f2733ee76 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -42,8 +42,7 @@ class HLSCustomOp(CustomOp):
 
     def code_generation(self, model):
         node = self.onnx_node
-        self.generate_weights(model)
-        self.generate_thresholds(model)
+        self.generate_params(model)
         self.global_includes()
         self.defines()
         self.read_npy_data()
@@ -78,12 +77,7 @@ class HLSCustomOp(CustomOp):
         builder.build(code_gen_dir)
         self.set_nodeattr("executable_path", builder.executable_path)
 
-    @abstractmethod
-    def generate_weights(self, context):
-        pass
-
-    @abstractmethod
-    def generate_thresholds(self, context):
+    def generate_params(self, model):
         pass
 
     @abstractmethod
diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index fdc66ce2b..58d60e7c4 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -149,7 +149,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         assert ret.shape[2] == n_thres_steps
         return ret
 
-    def generate_weights(self, model):
+    def generate_params(self, model):
+        # weights
         weights = model.get_initializer(self.onnx_node.input[1])
         # convert weights into hlslib-compatible format
         weight_tensor = self.get_hls_compatible_weight_tensor(weights)
@@ -184,8 +185,7 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             )
         f_weights.write(weight_hls_code)
         f_weights.close()
-
-    def generate_thresholds(self, model):
+        # thresholds
         thresholds = model.get_initializer(self.onnx_node.input[2])
         if thresholds is not None:
             threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
-- 
GitLab