Skip to content
Snippets Groups Projects
Commit bb5a4a42 authored by Lucian Petrica
Browse files

Bitwidth optimizations to accumulators and thresholds, through clipping of...

Bitwidth optimizations to accumulators and thresholds, through clipping of too-large threshold values and explicit calculation of the bitwidth
parent d6495c52
No related branches found
No related tags found
No related merge requests found
......@@ -594,7 +594,6 @@ class StreamingFCLayer_Batch(HLSCustomOp):
thresholds = model.get_initializer(self.onnx_node.input[2])
if thresholds is not None:
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
tdt = DataType.INT32
# use UINT32 threshold export for bipolar times bipolar
inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
......@@ -604,8 +603,40 @@ class StreamingFCLayer_Batch(HLSCustomOp):
bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1
inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode)
wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode)
if inp_is_bipolar and wt_is_bipolar:
tdt = DataType.UINT32
# set threshold datatype (and accumulator datatype implicitly)
min_threshold = thresholds.min()
max_threshold = thresholds.max()
min_weight = weights.min()
max_weight = weights.max()
perceptive_field_elems = self.get_nodeattr("MW")
min_input = self.get_input_datatype().min()
max_input = self.get_input_datatype().max()
# calculate minimum and maximum values of accumulator
# assume inputs span the whole range of the input datatype
acc_min = perceptive_field_elems * min(
min_weight * max_input,
min_weight * min_input,
max_weight * max_input,
max_weight * min_input,
)
acc_max = perceptive_field_elems * max(
min_weight * max_input,
min_weight * min_input,
max_weight * max_input,
max_weight * min_input,
)
# get range required by threshold values
tdt_min = min(acc_min, min_threshold)
tdt_max = max(acc_max, max_threshold)
if tdt_min < 0:
if abs(tdt_min) > tdt_max:
tdt = DataType.get_smallest_possible(tdt_min)
else:
tdt = DataType.get_smallest_possible(0 - tdt_max)
else:
tdt = DataType.get_smallest_possible(tdt_max)
thresholds_hls_code = numpy_to_hls_code(
threshold_tensor, tdt, "thresholds", False, True
)
......
......@@ -279,7 +279,22 @@ class Thresholding_Batch(HLSCustomOp):
thresholds = model.get_initializer(self.onnx_node.input[1])
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
tdt = DataType.INT32
min_threshold = thresholds.min()
max_threshold = thresholds.max()
min_input = self.get_input_datatype().min()
max_input = self.get_input_datatype().max()
# get range required by threshold values
tdt_min = min(min_input, min_threshold)
tdt_max = max(max_input, max_threshold)
if tdt_min < 0:
if abs(tdt_min) > tdt_max:
tdt = DataType.get_smallest_possible(tdt_min)
else:
tdt = DataType.get_smallest_possible(0 - tdt_max - 1)
else:
tdt = DataType.get_smallest_possible(tdt_max)
thresholds_hls_code = numpy_to_hls_code(
threshold_tensor, tdt, "thresholds", False, True
)
......
......@@ -51,10 +51,20 @@ class RoundAndClipThresholds(Transformation):
model.set_tensor_datatype(n.input[1], idtype)
graph_modified = True
if idtype.is_integer() and not idtype.signed() and (Tnew < 0).any():
# clip any negative thresholds
# clip any negative thresholds if input is unsigned
Tnew = np.clip(Tnew, 0, None)
model.set_initializer(n.input[1], Tnew)
# use same datatype as inputs for thresholds
model.set_tensor_datatype(n.input[1], idtype)
graph_modified = True
if idtype.is_integer() and (
(Tnew < (idtype.min() - 1)).any()
or (Tnew > (idtype.max() + 1)).any()
):
# clip any large thresholds to input range + 1
Tnew = np.clip(Tnew, idtype.min() - 1, idtype.max() + 1)
model.set_initializer(n.input[1], Tnew)
# use same datatype as inputs for thresholds
model.set_tensor_datatype(n.input[1], idtype)
graph_modified = True
return (model, graph_modified)
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment