Skip to content
Snippets Groups Projects
Commit 7afc0986 authored by icolbert's avatar icolbert
Browse files

Update minimize_accumulator_width for VVAU

parent b8386edf
No related branches found
No related tags found
No related merge requests found
......@@ -106,75 +106,95 @@ class VectorVectorActivation(HLSCustomOp):
def minimize_accumulator_width(self, model):
    """Minimize the accumulator bit width according to the weight values,
    input data types, and size of dot product.

    The accumulator range is computed from the weight initializer (or, for
    runtime-writeable weights, from the worst-case values of the weight
    datatype). If a threshold initializer is present and usable, thresholds
    lying outside the accumulator range are clipped, written back to the
    model, and the resulting datatype must be able to represent them.

    Args:
        model: model wrapper providing ``get_initializer``,
            ``set_initializer`` and ``find_direct_successors`` for this node.

    Returns:
        The DataType stored in the ``accDataType`` node attribute.

    Raises:
        AssertionError: if a threshold value cannot be represented by the
            selected datatype.

    Side effects:
        Sets the ``accDataType`` node attribute; sets ``outputDataType`` when
        no thresholds are used; may overwrite the threshold initializer with
        clipped values (a warning is emitted in that case).
    """
    weights = model.get_initializer(self.onnx_node.input[1])
    k_h, k_w = self.get_nodeattr("Kernel")
    fm = self.get_nodeattr("Channels")
    # put weights into the shape expected by calculate_matvec_accumulator_range
    weights = weights.reshape(fm, k_h * k_w).transpose()
    # since in the calculation the values of the weight matrix are used,
    # for the bipolar case they need to be converted to bipolar
    if self.get_nodeattr("binaryXnorMode"):
        weights = 2 * weights - 1
    if len(self.onnx_node.input) > 2:
        thresholds = model.get_initializer(self.onnx_node.input[2])
    else:
        thresholds = None
    idt = self.get_input_datatype()
    # if runtime-writeable weights, then the values of the weights can
    # change and we need to use the worst-case values from the datatypes
    if self.get_nodeattr("runtime_writeable_weights"):
        wdt = self.get_weight_datatype()
        # evaluate the range once with every weight at the datatype minimum
        # and once with every weight at the datatype maximum, then take the
        # envelope of both results
        lower_worst = wdt.min() * np.ones_like(weights)
        lower_range = calculate_matvec_accumulator_range(lower_worst, idt)
        upper_worst = wdt.max() * np.ones_like(weights)
        upper_range = calculate_matvec_accumulator_range(upper_worst, idt)
        acc_min = min(min(lower_range), min(upper_range))
        acc_max = max(max(lower_range), max(upper_range))
        thresholds = None  # range of thresholds are also runtime-writeable
    # if not runtime-writeable weights, then we can calculate the min
    # and max values of the accumulation range using knowledge of the
    # weights and input data types since they are fixed
    # (bounds derived in https://arxiv.org/abs/2301.13376)
    else:
        (acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)
    # if the thresholds can be used to determine range, then adjust the range
    # according to the known values of the thresholds
    if thresholds is not None:
        threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
        # set threshold datatype (and accumulator datatype implicitly)
        min_threshold = thresholds.min()
        max_threshold = thresholds.max()
        # clip threshold values
        clip_upper = None
        clip_lower = None
        if max_threshold > acc_max + 1:
            clip_upper = acc_max + 1
        if min_threshold < acc_min:
            clip_lower = acc_min
        if (clip_lower is not None) or (clip_upper is not None):
            warnings.warn(
                "Clipping some thresholds in %s" % self.onnx_node.name
            )
            thresholds = np.clip(thresholds, clip_lower, clip_upper)
            # persist the clipped thresholds and recompute everything that
            # was derived from the unclipped values
            model.set_initializer(self.onnx_node.input[2], thresholds)
            threshold_tensor = self.get_hls_compatible_threshold_tensor(
                thresholds
            )
            min_threshold = thresholds.min()
            max_threshold = thresholds.max()
        # get range required by threshold values
        tdt_min = min(acc_min, min_threshold)
        tdt_max = max(acc_max, max_threshold)
        if tdt_min < 0:
            # signed range: size the type by whichever bound is larger in
            # magnitude; -tdt_max - 1 is the negative value whose smallest
            # type is also requested when the positive bound dominates
            if abs(tdt_min) > tdt_max:
                tdt = DataType.get_smallest_possible(tdt_min)
            else:
                tdt = DataType.get_smallest_possible(-tdt_max - 1)
        else:
            tdt = DataType.get_smallest_possible(tdt_max)
        assert np.vectorize(tdt.allowed)(
            threshold_tensor
        ).all(), "Thresholds in %s can't be expressed with type %s" % (
            self.onnx_node.name,
            str(tdt),
        )
        adt = tdt  # Set activation datatype to the threshold datatype
    else:
        if acc_min < 0:
            # same signed-range sizing as above, on the accumulator bounds
            if abs(acc_min) > acc_max:
                adt = DataType.get_smallest_possible(acc_min)
            else:
                adt = DataType.get_smallest_possible(-acc_max - 1)
        else:
            adt = DataType.get_smallest_possible(acc_max)
        # if this is the last node in the graph, then ensure the datatype is
        # divisible by 8 bits
        if model.find_direct_successors(self.onnx_node) is None:
            bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
            new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
            adt = DataType[new_adt_name]
        # for no-activation nodes, output dt = acc dt
        self.set_nodeattr("outputDataType", adt.name)
    self.set_nodeattr("accDataType", adt.name)
    return DataType[self.get_nodeattr("accDataType")]
def minimize_weight_bit_width(self, model):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment