Skip to content
Snippets Groups Projects
Commit 8850581b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[StreamingFC] check num inputs before trying to grab thresholds

parent 2a9e731a
No related branches found
No related tags found
No related merge requests found
......@@ -185,35 +185,36 @@ class StreamingFCLayer_Batch(HLSCustomOp):
f_weights.write(weight_hls_code)
f_weights.close()
# thresholds
thresholds = model.get_initializer(self.onnx_node.input[2])
if thresholds is not None:
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
tdt = DataType.INT32
# use UINT32 threshold export for bipolar times bipolar
inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
if inp_is_bipolar and wt_is_bipolar:
tdt = DataType.UINT32
thresholds_hls_code = numpy_to_hls_code(
threshold_tensor, tdt, "thresholds", False, True
)
# write thresholds into thresh.h
code_gen_dir = self.get_nodeattr("code_gen_dir")
f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
tdt_hls = tdt.get_hls_datatype_str()
odt_hls = self.get_output_datatype().get_hls_datatype_str()
f_thresh.write(
"static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
self.get_nodeattr("TMEM"),
self.get_nodeattr("PE"),
threshold_tensor.shape[-1],
tdt_hls,
odt_hls,
self.get_nodeattr("ActVal"),
if len(self.onnx_node.input) > 2:
thresholds = model.get_initializer(self.onnx_node.input[2])
if thresholds is not None:
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
tdt = DataType.INT32
# use UINT32 threshold export for bipolar times bipolar
inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
if inp_is_bipolar and wt_is_bipolar:
tdt = DataType.UINT32
thresholds_hls_code = numpy_to_hls_code(
threshold_tensor, tdt, "thresholds", False, True
)
)
f_thresh.write(thresholds_hls_code)
f_thresh.close()
# write thresholds into thresh.h
code_gen_dir = self.get_nodeattr("code_gen_dir")
f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
tdt_hls = tdt.get_hls_datatype_str()
odt_hls = self.get_output_datatype().get_hls_datatype_str()
f_thresh.write(
"static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
self.get_nodeattr("TMEM"),
self.get_nodeattr("PE"),
threshold_tensor.shape[-1],
tdt_hls,
odt_hls,
self.get_nodeattr("ActVal"),
)
)
f_thresh.write(thresholds_hls_code)
f_thresh.close()
def execute_node(self, context, graph):
node = self.onnx_node
......@@ -247,7 +248,7 @@ class StreamingFCLayer_Batch(HLSCustomOp):
# execute the precompiled model
super().exec_precompiled_singlenode_model()
# load output npy file
super().npy_to_dynamic_output()
super().npy_to_dynamic_output(context)
def global_includes(self):
self.code_gen_dict["$GLOBALS$"] = ['#include "weights.hpp"']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment