Skip to content
Snippets Groups Projects
Commit 9c134a51 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[CustomOp] generalize param gen fxn to a single optional one

parent 4a9989df
No related branches found
No related tags found
No related merge requests found
......@@ -42,8 +42,7 @@ class HLSCustomOp(CustomOp):
def code_generation(self, model):
node = self.onnx_node
self.generate_weights(model)
self.generate_thresholds(model)
self.generate_params(model)
self.global_includes()
self.defines()
self.read_npy_data()
......@@ -78,12 +77,7 @@ class HLSCustomOp(CustomOp):
builder.build(code_gen_dir)
self.set_nodeattr("executable_path", builder.executable_path)
@abstractmethod
def generate_weights(self, context):
pass
@abstractmethod
def generate_thresholds(self, context):
def generate_params(self, model):
pass
@abstractmethod
......
......@@ -149,7 +149,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
assert ret.shape[2] == n_thres_steps
return ret
def generate_weights(self, model):
def generate_params(self, model):
# weights
weights = model.get_initializer(self.onnx_node.input[1])
# convert weights into hlslib-compatible format
weight_tensor = self.get_hls_compatible_weight_tensor(weights)
......@@ -184,8 +185,7 @@ class StreamingFCLayer_Batch(HLSCustomOp):
)
f_weights.write(weight_hls_code)
f_weights.close()
def generate_thresholds(self, model):
# thresholds
thresholds = model.get_initializer(self.onnx_node.input[2])
if thresholds is not None:
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment