Commit 37e43b87 authored by Yaman Umuroglu

[CodeGen, Refactor] made dir and executable CustomOp attributes

also move "if thresholds exist" logic into generate_thresholds
parent d7c16dce
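The practical effect is that the code generation directory and the path to the compiled executable are now declared as ONNX node attributes (via get_nodeattr_types) instead of plain Python instance variables, so they are stored on the node and travel with the serialized model. A minimal usage sketch, relying only on the get_nodeattr accessor that already appears elsewhere in this commit; the helper function itself is hypothetical and not part of the change:

def report_codegen_paths(op):
    # `op` is assumed to be an HLSCustomOp instance wrapping an ONNX node.
    # Both attributes are declared as optional strings with default ""
    # ("s", False, ""), so they stay empty until the code generation and
    # compilation steps fill them in.
    print("code_gen_dir:", op.get_nodeattr("code_gen_dir"))
    print("executable_path:", op.get_nodeattr("executable_path"))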
@@ -41,14 +41,13 @@ class HLSCustomOp(CustomOp):
-        self.code_gen_dir = util.get_by_name(onnx_node.attribute, "code_gen_dir")
-        self.executable_path = ""
+    def get_nodeattr_types(self):
+        return {"code_gen_dir": ("s", False, ""), "executable_path": ("s", False, "")}
 
     def code_generation(self, model):
         node = self.onnx_node
-        if node.op_type == "StreamingFCLayer_Batch":
-            self.generate_weights(model)
-            try:
-                self.generate_thresholds(model)
-            except:
-                pass
+        self.generate_weights(model)
+        self.generate_thresholds(model)
         self.global_includes()
         self.defines()
         self.read_npy_data()
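With the op_type check and the bare try/except removed, code_generation now runs the same sequence for every HLS custom op, and generate_thresholds itself decides whether a thresh.h needs to be written (see the later hunk in StreamingFCLayer_Batch). A hedged driver sketch of how this could be invoked over a model; get_custom_op_instance and the "finn" domain check are assumptions for illustration, not APIs confirmed by this commit:

def generate_all_code(model):
    # Walk the graph and trigger code generation for every FINN custom op.
    # get_custom_op_instance is a hypothetical helper that maps node.op_type
    # to the matching HLSCustomOp subclass.
    for node in model.graph.node:
        if node.domain == "finn":
            inst = get_custom_op_instance(node)
            inst.code_generation(model)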
@@ -15,7 +15,7 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         super().__init__(onnx_node)
 
     def get_nodeattr_types(self):
-        return {
+        my_attrs = {
             "WMEM": ("i", True, 0),
             "TMEM": ("i", True, 0),
             "PE": ("i", True, 0),
@@ -29,6 +29,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             "weightDataType": ("s", True, ""),
             "outputDataType": ("s", True, ""),
         }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
 
     def make_shape_compatible_op(self):
         pass
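The my_attrs.update(super().get_nodeattr_types()) pattern means each subclass lists only its own attributes and picks up code_gen_dir and executable_path from HLSCustomOp automatically. A sketch of the same pattern applied to a hypothetical new op; the class name and the SomeDim attribute are made up for illustration:

from finn.custom_op.fpgadataflow import HLSCustomOp

class MyNewStreamingOp(HLSCustomOp):
    def get_nodeattr_types(self):
        # Op-specific attributes only; the base class contributes
        # code_gen_dir and executable_path via the update() call.
        my_attrs = {"SomeDim": ("i", True, 0)}
        my_attrs.update(super().get_nodeattr_types())
        return my_attrs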
@@ -186,32 +188,33 @@ class StreamingFCLayer_Batch(HLSCustomOp):
 
     def generate_thresholds(self, model):
         thresholds = model.get_initializer(self.onnx_node.input[2])
-        threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
-        tdt = DataType.INT32
-        # use UINT32 threshold export for bipolar times bipolar
-        inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
-        wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
-        if inp_is_bipolar and wt_is_bipolar:
-            tdt = DataType.UINT32
-        thresholds_hls_code = numpy_to_hls_code(
-            threshold_tensor, tdt, "thresholds", False, True
-        )
-        # write thresholds into thresh.h
-        f_thresh = open("{}/thresh.h".format(self.tmp_dir), "w")
-        tdt_hls = tdt.get_hls_datatype_str()
-        odt_hls = self.get_output_datatype().get_hls_datatype_str()
-        f_thresh.write(
-            "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
-                self.get_nodeattr("TMEM"),
-                self.get_nodeattr("PE"),
-                threshold_tensor.shape[-1],
-                tdt_hls,
-                odt_hls,
-                self.get_nodeattr("ActVal"),
-            )
-        )
-        f_thresh.write(thresholds_hls_code)
-        f_thresh.close()
+        if thresholds is not None:
+            threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
+            tdt = DataType.INT32
+            # use UINT32 threshold export for bipolar times bipolar
+            inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
+            wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
+            if inp_is_bipolar and wt_is_bipolar:
+                tdt = DataType.UINT32
+            thresholds_hls_code = numpy_to_hls_code(
+                threshold_tensor, tdt, "thresholds", False, True
+            )
+            # write thresholds into thresh.h
+            f_thresh = open("{}/thresh.h".format(self.tmp_dir), "w")
+            tdt_hls = tdt.get_hls_datatype_str()
+            odt_hls = self.get_output_datatype().get_hls_datatype_str()
+            f_thresh.write(
+                "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
+                    self.get_nodeattr("TMEM"),
+                    self.get_nodeattr("PE"),
+                    threshold_tensor.shape[-1],
+                    tdt_hls,
+                    odt_hls,
+                    self.get_nodeattr("ActVal"),
+                )
+            )
+            f_thresh.write(thresholds_hls_code)
+            f_thresh.close()
 
     def execute_node(self, context, graph):
         node = self.onnx_node
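For concreteness, here is what the format string above produces; the TMEM/PE/threshold-step counts and datatype strings below are assumed example values, not taken from the commit:

# Example with assumed values: TMEM=16, PE=4, 15 threshold steps,
# tdt_hls="ap_int<32>", odt_hls="ap_uint<4>", ActVal=0.
decl = "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
    16, 4, 15, "ap_int<32>", "ap_uint<4>", 0
)
# decl == "static ThresholdsActivation<16,4,15,ap_int<32>,ap_uint<4>,0> threshs = "
# generate_thresholds then appends the brace-initializer produced by
# numpy_to_hls_code and writes both pieces to thresh.h, but only when the
# node actually has a threshold initializer.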
@@ -9,11 +9,13 @@ from finn.custom_op.fpgadataflow import HLSCustomOp
 class StreamingMaxPool(HLSCustomOp):
     def get_nodeattr_types(self):
-        return {
+        my_attrs = {
             "ImgDim": ("i", True, 0),
             "PoolDim": ("i", True, 0),
             "NumChannels": ("i", True, 0),
         }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
 
     def make_shape_compatible_op(self):
         pass
@@ -9,11 +9,13 @@ from finn.custom_op.fpgadataflow import HLSCustomOp
 class StreamingMaxPool_Batch(HLSCustomOp):
     def get_nodeattr_types(self):
-        return {
+        my_attrs = {
             "ImgDim": ("i", True, 0),
             "PoolDim": ("i", True, 0),
             "NumChannels": ("i", True, 0),
         }
+        my_attrs.update(super().get_nodeattr_types())
+        return my_attrs
 
     def make_shape_compatible_op(self):
         pass
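Since the pooling ops now inherit the optional string attributes as well, a node can be built with just its required integer attributes; code_gen_dir and executable_path fall back to their "" defaults until the code generation flow sets them. A construction sketch in which the tensor names and the "finn" domain are assumptions for illustration:

from onnx import helper

maxpool_node = helper.make_node(
    "StreamingMaxPool_Batch",
    ["inp"],          # input tensor name (assumed)
    ["outp"],         # output tensor name (assumed)
    domain="finn",    # assumed custom-op domain
    ImgDim=4,
    PoolDim=2,
    NumChannels=2,
)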