Skip to content
Snippets Groups Projects
Commit 64306ba7 authored by auphelia
Browse files

[CustomOp] Update ImgDim and numReps in thresholding

parent 91591f0f
No related branches found
No related tags found
No related merge requests found
......@@ -602,17 +602,17 @@ class Thresholding_Batch(HLSCustomOp):
# TODO check and add whatever missing
def defines(self, var):
if self.get_nodeattr("mem_mode") == "const":
numReps = 1
else:
numInputVectors = list(self.get_nodeattr("numInputVectors"))
numReps = int(np.prod(numInputVectors))
numReps = 1
numInputVectors = list(self.get_nodeattr("numInputVectors"))
total_spatial_size = int(np.prod(numInputVectors))
self.code_gen_dict["$DEFINES$"] = [
"""#define NumChannels1 {}\n #define PE1 {}\n #define numReps {}""".format(
"""#define NumChannels1 {}\n #define PE1 {}\n #define numReps {}\n
#define ImgDim1 {}""".format(
self.get_nodeattr("NumChannels"),
self.get_nodeattr("PE"),
numReps,
total_spatial_size,
)
]
if self.get_nodeattr("mem_mode") == "decoupled":
......@@ -653,7 +653,7 @@ class Thresholding_Batch(HLSCustomOp):
npy_in = "%s/thresholds.npy" % code_gen_dir
self.code_gen_dict["$READNPYDATA$"].append(
'npy2apintstream<%s, %s, %d, %s>("%s", weights, false, numReps);'
'npy2apintstream<%s, %s, %d, %s>("%s", weights, false, ImgDim1);'
% (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
)
......@@ -675,18 +675,13 @@ class Thresholding_Batch(HLSCustomOp):
def docompute(self):
tmpl_args = self.get_template_param_values()
# TODO: why put some template parameters into defines and not others?
# should ImgDim be defined or just filled in here like we do now?
node = self.onnx_node
inp_vecs = self.get_nodeattr("numInputVectors")
total_spatial_size = int(np.prod(inp_vecs))
mem_mode = self.get_nodeattr("mem_mode")
if mem_mode == "const":
self.code_gen_dict["$DOCOMPUTE$"] = [
"""{}<{}, NumChannels1, PE1, {}, {}>
"""{}<ImgDim1, NumChannels1, PE1, {}, {}>
(in0, out, threshs, numReps);""".format(
node.op_type,
total_spatial_size,
tmpl_args["TSrcI"],
tmpl_args["TDstI"],
)
......@@ -696,10 +691,9 @@ class Thresholding_Batch(HLSCustomOp):
# - for cppsim the repetition comes from the threshold stream reader+input
# - for synth the unit runs continuously anyway (ap_ctrl_none)
self.code_gen_dict["$DOCOMPUTE$"] = [
"""{}<{}, NumChannels1, PE1, {}, {}, ActVal1, ThresType1, NumSteps1>
(in0, out, weights, 1);""".format(
"""{}<ImgDim1, NumChannels1, PE1, {}, {}, ActVal1, ThresType1, NumSteps1>
(in0, out, weights, numReps);""".format(
"Thresholding_Stream_Batch",
total_spatial_size,
tmpl_args["TSrcI"],
tmpl_args["TDstI"],
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment