Skip to content
Snippets Groups Projects
Commit 7afc0986 authored by icolbert's avatar icolbert
Browse files

Update minimize_accumulator_width for VVAU

parent b8386edf
No related branches found
No related tags found
No related merge requests found
...@@ -106,75 +106,95 @@ class VectorVectorActivation(HLSCustomOp): ...@@ -106,75 +106,95 @@ class VectorVectorActivation(HLSCustomOp):
def minimize_accumulator_width(self, model):
    """Minimize the accumulator bit width according to the weight values,
    input data types, and size of dot product.

    The accumulator range is derived from the weight initializer (or, for
    runtime-writeable weights, from the worst case permitted by the weight
    datatype) together with the input datatype. If a thresholds initializer
    is present and usable, thresholds lying outside the accumulator range
    are clipped and written back to ``model``, and the accumulator datatype
    is chosen so that it can also represent every threshold.

    Side effects: sets the ``accDataType`` node attribute; for nodes
    without usable thresholds also sets ``outputDataType``; may overwrite
    the thresholds initializer in ``model`` when clipping is required.

    Args:
        model: model wrapper that owns this node; used to read/write
            initializers and to query the node's successors.

    Returns:
        The DataType named by the updated ``accDataType`` attribute.

    Raises:
        AssertionError: if the (possibly clipped) thresholds cannot be
            expressed with the selected datatype.
    """

    def _smallest_dtype(range_min, range_max):
        # Smallest DataType covering [range_min, range_max]. For a range
        # that includes negatives, size for whichever end needs more bits:
        # the negative end directly, or -(max) - 1 so that the positive
        # end fits as well.
        if range_min < 0:
            if abs(range_min) > range_max:
                return DataType.get_smallest_possible(range_min)
            return DataType.get_smallest_possible(-range_max - 1)
        return DataType.get_smallest_possible(range_max)

    weights = model.get_initializer(self.onnx_node.input[1])
    k_h, k_w = self.get_nodeattr("Kernel")
    fm = self.get_nodeattr("Channels")
    # put weights into the shape expected by calculate_matvec_accumulator_range
    weights = weights.reshape(fm, k_h * k_w).transpose()
    # since in the calculation the values of the weight matrix are used,
    # for the bipolar case they need to be converted to bipolar
    if self.get_nodeattr("binaryXnorMode"):
        weights = 2 * weights - 1
    if len(self.onnx_node.input) > 2:
        thresholds = model.get_initializer(self.onnx_node.input[2])
    else:
        thresholds = None
    idt = self.get_input_datatype()

    if self.get_nodeattr("runtime_writeable_weights"):
        # if runtime-writeable weights, then the values of the weights can
        # change and we need to use the worst-case values from the datatypes
        wdt = self.get_weight_datatype()
        lower_worst = wdt.min() * np.ones_like(weights)
        lower_range = calculate_matvec_accumulator_range(lower_worst, idt)
        # the upper worst case must use the datatype's maximum, not its
        # minimum, otherwise the all-max weight matrix is never evaluated
        upper_worst = wdt.max() * np.ones_like(weights)
        upper_range = calculate_matvec_accumulator_range(upper_worst, idt)
        # take the extremes over both worst-case ranges
        acc_min = min(min(lower_range), min(upper_range))
        acc_max = max(max(lower_range), max(upper_range))
        thresholds = None  # range of thresholds are also runtime-writeable
    else:
        # if not runtime-writeable weights, then we can calculate the min
        # and max values of the accumulation range using knowledge of the
        # weights and input data types since they are fixed
        # (bounds presumably those derived in https://arxiv.org/abs/2301.13376,
        # as referenced by the previous revision of this method)
        (acc_min, acc_max) = calculate_matvec_accumulator_range(weights, idt)

    # if the thresholds can be used to determine range, then adjust the range
    # according to the known values of the thresholds
    if thresholds is not None:
        threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
        # set threshold datatype (and accumulator datatype implicitly)
        min_threshold = thresholds.min()
        max_threshold = thresholds.max()
        # clip threshold values
        clip_upper = None
        clip_lower = None
        if max_threshold > acc_max + 1:
            clip_upper = acc_max + 1
        if min_threshold < acc_min:
            clip_lower = acc_min
        if (clip_lower is not None) or (clip_upper is not None):
            warnings.warn(
                "Clipping some thresholds in %s" % self.onnx_node.name
            )
            thresholds = np.clip(thresholds, clip_lower, clip_upper)
            # persist the clipped thresholds and refresh derived values
            model.set_initializer(self.onnx_node.input[2], thresholds)
            threshold_tensor = self.get_hls_compatible_threshold_tensor(
                thresholds
            )
            min_threshold = thresholds.min()
            max_threshold = thresholds.max()
        # get range required by threshold values
        tdt_min = min(acc_min, min_threshold)
        tdt_max = max(acc_max, max_threshold)
        tdt = _smallest_dtype(tdt_min, tdt_max)
        assert np.vectorize(tdt.allowed)(
            threshold_tensor
        ).all(), "Thresholds in %s can't be expressed with type %s" % (
            self.onnx_node.name,
            str(tdt),
        )
        adt = tdt  # Set activation datatype to the threshold datatype
    else:
        adt = _smallest_dtype(acc_min, acc_max)
        # if this is the last node in the graph, then ensure the datatype is
        # divisible by 8 bits
        if model.find_direct_successors(self.onnx_node) is None:
            bw = roundup_to_integer_multiple(adt.bitwidth(), 8)
            new_adt_name = adt.name.replace(str(adt.bitwidth()), str(bw))
            adt = DataType[new_adt_name]
        # for no-activation nodes, output dt = acc dt
        self.set_nodeattr("outputDataType", adt.name)
    self.set_nodeattr("accDataType", adt.name)
    return DataType[self.get_nodeattr("accDataType")]
def minimize_weight_bit_width(self, model): def minimize_weight_bit_width(self, model):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment