From 8850581ba4ce620156bf6777964bb7e53b77f9b4 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 3 Dec 2019 01:23:40 +0000 Subject: [PATCH] [StreamingFC] check num inputs before trying to grab thresholds --- .../fpgadataflow/streamingfclayer_batch.py | 59 ++++++++++--------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index e375e239e..27272bc10 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -185,35 +185,36 @@ class StreamingFCLayer_Batch(HLSCustomOp): f_weights.write(weight_hls_code) f_weights.close() # thresholds - thresholds = model.get_initializer(self.onnx_node.input[2]) - if thresholds is not None: - threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) - tdt = DataType.INT32 - # use UINT32 threshold export for bipolar times bipolar - inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR - wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR - if inp_is_bipolar and wt_is_bipolar: - tdt = DataType.UINT32 - thresholds_hls_code = numpy_to_hls_code( - threshold_tensor, tdt, "thresholds", False, True - ) - # write thresholds into thresh.h - code_gen_dir = self.get_nodeattr("code_gen_dir") - f_thresh = open("{}/thresh.h".format(code_gen_dir), "w") - tdt_hls = tdt.get_hls_datatype_str() - odt_hls = self.get_output_datatype().get_hls_datatype_str() - f_thresh.write( - "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format( - self.get_nodeattr("TMEM"), - self.get_nodeattr("PE"), - threshold_tensor.shape[-1], - tdt_hls, - odt_hls, - self.get_nodeattr("ActVal"), + if len(self.onnx_node.input) > 2: + thresholds = model.get_initializer(self.onnx_node.input[2]) + if thresholds is not None: + threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) + tdt = DataType.INT32 + # use UINT32 threshold export for bipolar times bipolar + inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR + wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR + if inp_is_bipolar and wt_is_bipolar: + tdt = DataType.UINT32 + thresholds_hls_code = numpy_to_hls_code( + threshold_tensor, tdt, "thresholds", False, True ) - ) - f_thresh.write(thresholds_hls_code) - f_thresh.close() + # write thresholds into thresh.h + code_gen_dir = self.get_nodeattr("code_gen_dir") + f_thresh = open("{}/thresh.h".format(code_gen_dir), "w") + tdt_hls = tdt.get_hls_datatype_str() + odt_hls = self.get_output_datatype().get_hls_datatype_str() + f_thresh.write( + "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format( + self.get_nodeattr("TMEM"), + self.get_nodeattr("PE"), + threshold_tensor.shape[-1], + tdt_hls, + odt_hls, + self.get_nodeattr("ActVal"), + ) + ) + f_thresh.write(thresholds_hls_code) + f_thresh.close() def execute_node(self, context, graph): node = self.onnx_node @@ -247,7 +248,7 @@ class StreamingFCLayer_Batch(HLSCustomOp): # execute the precompiled model super().exec_precompiled_singlenode_model() # load output npy file - super().npy_to_dynamic_output() + super().npy_to_dynamic_output(context) def global_includes(self): self.code_gen_dict["$GLOBALS$"] = ['#include "weights.hpp"'] -- GitLab