From 9c134a5116c6bf5c30e7a255a6e56217b54c41db Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 3 Dec 2019 00:54:56 +0000 Subject: [PATCH] [CustomOp] generalize param gen fxn to a single optional one --- src/finn/custom_op/fpgadataflow/__init__.py | 10 ++-------- .../custom_op/fpgadataflow/streamingfclayer_batch.py | 6 +++--- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index 0d9a306ba..f2733ee76 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -42,8 +42,7 @@ class HLSCustomOp(CustomOp): def code_generation(self, model): node = self.onnx_node - self.generate_weights(model) - self.generate_thresholds(model) + self.generate_params(model) self.global_includes() self.defines() self.read_npy_data() @@ -78,12 +77,7 @@ class HLSCustomOp(CustomOp): builder.build(code_gen_dir) self.set_nodeattr("executable_path", builder.executable_path) - @abstractmethod - def generate_weights(self, context): - pass - - @abstractmethod - def generate_thresholds(self, context): + def generate_params(self, model): pass @abstractmethod diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index fdc66ce2b..58d60e7c4 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -149,7 +149,8 @@ class StreamingFCLayer_Batch(HLSCustomOp): assert ret.shape[2] == n_thres_steps return ret - def generate_weights(self, model): + def generate_params(self, model): + # weights weights = model.get_initializer(self.onnx_node.input[1]) # convert weights into hlslib-compatible format weight_tensor = self.get_hls_compatible_weight_tensor(weights) @@ -184,8 +185,7 @@ class StreamingFCLayer_Batch(HLSCustomOp): ) f_weights.write(weight_hls_code) f_weights.close() - - def generate_thresholds(self, model): + # thresholds thresholds = model.get_initializer(self.onnx_node.input[2]) if thresholds is not None: threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) -- GitLab