Commit dd00db03 authored by auphelia

[CustomOp] Add output activation support to VVAU node

parent c32a6779
@@ -11,13 +11,6 @@ from finn.util.data_packing import (
rtlsim_output_to_npy,
)
# ONNX i/o tensor shape assumptions for Vector_Vector_Activate_Batch:
# input 0 is the input tensor, shape (..., i_size) = (..., MW)
# input 1 is the weight tensor, shape (i_size, o_size) = (MW, MH)
# (optional) input 2 is the thresholds tensor, shape (o_size, n_thres)
# output 0 is the output tensor, shape (..., o_size) = (..., MH)
# the ... here can be any shape (representing groups of vectors)
class Vector_Vector_Activate_Batch(HLSCustomOp):
"""Class that corresponds to finn-hlslib Vector_Vector_Activate_Batch function"""
@@ -158,22 +151,13 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
ret = dict()
inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
out_hls_str = self.get_output_datatype().get_hls_datatype_str()
# inp_is_binary = self.get_input_datatype() == DataType.BINARY
# wt_is_binary = self.get_weight_datatype() == DataType.BINARY
inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
# fill in TSrcI and TWeightI
# TODO handle non-bipolar binary inputs
if inp_is_bipolar and wt_is_bipolar:
ret["TSrcI"] = "Recast<XnorMul>"
ret["TWeightI"] = "Identity"
elif (not inp_is_bipolar) and wt_is_bipolar:
ret["TSrcI"] = "Slice<%s>" % inp_hls_str
ret["TWeightI"] = "Recast<Binary>"
elif inp_is_bipolar and (not wt_is_bipolar):
ret["TSrcI"] = "Recast<Binary>"
ret["TWeightI"] = "Identity"
elif (not inp_is_bipolar) and (not wt_is_bipolar):
if inp_is_bipolar or wt_is_bipolar:
raise Exception("VVAU node doesn't support bipolar values yet.")
else:
ret["TSrcI"] = "Slice<%s>" % inp_hls_str
ret["TWeightI"] = "Identity"
@@ -202,6 +186,33 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
ret = ret.reshape(1, pe, wmem, 1)
return ret
def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
ch = self.get_nodeattr("Channels")
pe = self.get_nodeattr("PE")
tmem = self.calc_tmem()
assert ch % pe == 0, "Requirement Channels divisible by PE is violated."
assert (
orig_thres_matrix.ndim == 2
), """Threshold matrix dimension is
not as expected (2)."""
n_thres_steps = orig_thres_matrix.shape[1]
ret = orig_thres_matrix
# distribute rows between PEs
ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
assert (
ret.shape[0] == pe
), """First dimension after distribution of the
rows between PEs is not as expected (pe)"""
assert (
ret.shape[1] == tmem
), """Second dimension after distribution of the
rows between PEs is not as expected (tmem)"""
assert (
ret.shape[2] == n_thres_steps
), """Third dimension after distribution of the
rows between PEs is not as expected (n_thres_steps)"""
return ret.reshape(1, pe, tmem, n_thres_steps)
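
get_hls_compatible_threshold_tensor distributes the per-channel threshold rows across the PE lanes and packs them into shape (1, pe, tmem, n_thres_steps). Below is a hedged numpy sketch of that reshaping, assuming the usual FINN interleaving convention (channel row i goes to PE i % pe) and purely illustrative sizes:

    import numpy as np

    ch, pe, n_thres_steps = 4, 2, 3    # illustrative sizes only
    tmem = ch // pe                    # 2 threshold memory entries per PE
    thres = np.arange(ch * n_thres_steps).reshape(ch, n_thres_steps)   # (4, 3)
    # interleave rows between PEs: (ch, n) -> (pe, tmem, n)
    interleaved = thres.reshape(tmem, pe, n_thres_steps).transpose(1, 0, 2)
    packed = interleaved.reshape(1, pe, tmem, n_thres_steps)           # (1, 2, 2, 3)
    # PE 0 now holds channel rows 0 and 2, PE 1 holds channel rows 1 and 3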
def generate_params(self, model, path):
# weights
weights = model.get_initializer(self.onnx_node.input[1])
@@ -238,6 +249,36 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
f_weights.write(weight_hls_code)
f_weights.close()
# save thresholds in thresh.h
if len(self.onnx_node.input) > 2:
thresholds = model.get_initializer(self.onnx_node.input[2])
if thresholds is not None:
threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
tdt = DataType.INT32
thresholds_hls_code = numpy_to_hls_code(
threshold_tensor, tdt, "thresholds", False, True
)
# write thresholds into thresh.h
f_thresh = open("{}/thresh.h".format(code_gen_dir), "w")
tdt_hls = tdt.get_hls_datatype_str()
# use binary to export bipolar activations
odt = self.get_output_datatype()
odt_hls = odt.get_hls_datatype_str()
f_thresh.write(
"static ThresholdsActivation<{},{},{},{},{},{},{}> threshs \
= ".format(
self.calc_tmem(),
self.get_nodeattr("PE"),
threshold_tensor.shape[-1],
tdt_hls,
odt_hls,
self.get_nodeattr("ActVal"),
"std::less_equal<%s>" % tdt_hls,
)
)
f_thresh.write(thresholds_hls_code)
f_thresh.close()
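
The thresh.h written above instantiates a finn-hlslib ThresholdsActivation with std::less_equal as the comparator and ActVal passed in as the starting output value. A hedged numpy sketch (not FINN or finn-hlslib code) of the activation semantics this is expected to implement, i.e. each output is ActVal plus the number of threshold steps the accumulator meets or exceeds:

    import numpy as np

    def apply_thresholds(acc, thresholds, act_val=0):
        # acc: accumulator values, shape (..., channels)
        # thresholds: per-channel threshold steps, shape (channels, n_thres_steps)
        return act_val + np.sum(thresholds <= acc[..., np.newaxis], axis=-1)

    # e.g. one channel with steps [0, 4, 8]: an accumulator of 5 crosses two steps
    print(apply_thresholds(np.array([5.0]), np.array([[0.0, 4.0, 8.0]])))   # [2]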
def execute_node(self, context, graph):
mode = self.get_nodeattr("exec_mode")
node = self.onnx_node
@@ -336,9 +377,8 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
def global_includes(self):
self.code_gen_dict["$GLOBALS$"] = ['#include "weights.hpp"']
self.code_gen_dict["$GLOBALS$"] += ['#include "activations.hpp"']
# if self.calc_tmem() != 0:
# # TODO find a better way of checking for no pregenerated thresholds
# self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
if self.calc_tmem() != 0:
self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
def defines(self, var):
dim = self.get_nodeattr("Dim")